In [None]:
import torch
from torch.utils import data
from tqdm import tqdm
import numpy as np
from datetime import datetime
import json
import wandb
import math
import os

from models import VGG16SiameseNetwork, ResNet18SiameseNetwork, InceptionV3SiameseNetwork, TripletLoss

import sys
sys.path.append("../")
from dataset_utils import ImageNet_Triplet

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
# The GPU name is printed so the cell output records which card was used.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(torch.cuda.get_device_name())
else:
    device = torch.device("cpu")

### Konstanten

In [3]:
ARCHITECTURE = "resnet18"
DATASET = "imagenet"
# IMAGE_NUM = 9600
BATCH_SIZE = 64
EPOCHS = 50
MARGIN = 100
LEARNING_RATE = 0.0001 # 0.0001
OUTPUT_DIM = 1000
# CLASSES = [
#     "tabby",
#     "tiger cat",
#     "Persian cat",
#     "Egyptian cat",
#     "cougar",
#     "lynx",
#     "polecat",
#     "tiger",
#     "tractor",
#     "goldfinch",
#     "warplane",
#     "garbage truck",
#     "grand piano",
#     "German shepherd",
#     "French bulldog",
#     "Eskimo dog",
# ]
OPTIMIZER = "adam"
MOMENTUM = None
WEIGHT_DECAY = None
match OPTIMIZER:
    case "sgd":
        MOMENTUM = 0.9
        WEIGHT_DECAY = 1e-4

SCHEDULER = "custom_cosine_annealing_without_restart"
SCHEDULER_MODULO = 1
GAMMA = None
match SCHEDULER:
    case "exponentiallr":
        GAMMA = 0.1
        SCHEDULER_MODULO = 5
    case "custom_cosine_annealing_without_restart":
        pass

In [None]:
# Authenticate this session with Weights & Biases before a run is started.
wandb.login()

In [None]:
# Hyperparameters and run metadata that are tracked together with the run.
run_config = {
    "architecture": ARCHITECTURE,
    "dataset": DATASET,
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    "margin": MARGIN,
    "learning_rate": LEARNING_RATE,
    "output_dim": OUTPUT_DIM,
    "optimizer": OPTIMIZER,
    "optimizer_momentum": MOMENTUM,
    "optimizer_weight_decay": WEIGHT_DECAY,
    "scheduler": SCHEDULER,
    "scheduler_modulo": SCHEDULER_MODULO,
    "scheduler_gamma": GAMMA
}

# Start the wandb run.
# NOTE(review): the project name is spelled "siamses_network"; it is left
# untouched so new runs keep landing in the existing project.
wandb.init(
    project="siamses_network",       # wandb project the run is logged to
    dir=r"C:\Users\mariu\Desktop",   # local directory for wandb run files
    config=run_config,
)

### Datensatz & Teildatensatz

In [6]:
NUM_WORKERS = 4

match DATASET:
    case "imagenet":
        dataset = ImageNet_Triplet(root=r"C:\Users\mariu\Documents\Development\Datasets\ImageNet_Subset")
        dataset_val = ImageNet_Triplet(root=r"C:\Users\mariu\Documents\Development\Datasets\ImageNet_Subset", split="val")

data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
data_loader_val = data.DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# dataset = ImageNet_Triplet(classes=CLASSES)
# subset, _ = dataset.getSubset(IMAGE_NUM)
# data_loader = data.DataLoader(subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

### Modell

In [7]:
match ARCHITECTURE:
    case "vgg16":
        siamese_network = VGG16SiameseNetwork(output_dim=OUTPUT_DIM)
    case "resnet18":
        siamese_network = ResNet18SiameseNetwork(output_dim=OUTPUT_DIM)
    case "inceptionv3":
        siamese_network = InceptionV3SiameseNetwork(output_dim=OUTPUT_DIM)

siamese_network = siamese_network.to(device)

### Modell trainieren

In [8]:
# One output folder per run, named <architecture>_<dataset>_<timestamp>.
run_stamp = datetime.today().strftime('%Y%m%d_%H%M%S')
run_name = "_".join([ARCHITECTURE.lower(), DATASET.lower(), run_stamp])
dest = os.path.join(r"C:\Users\mariu\Desktop\projects", "siamese_network", run_name)
# Raises if the folder already exists, so an earlier run is never overwritten.
os.makedirs(dest)

def save_model(epoch=0):
    """Write a checkpoint for the given epoch into the run folder ``dest``.

    Two files are created: ``json_<epoch>.json`` with the run's hyperparameters
    plus the current timestamp, and ``state_dict_<epoch>.pt`` with the
    network's state dict.

    Args:
        epoch: Value used as suffix in both file names.
    """
    metadata = {
        "architecture": ARCHITECTURE,
        "dataset": DATASET,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "margin": MARGIN,
        "learning_rate": LEARNING_RATE,
        "output_dim": OUTPUT_DIM,
        "optimizer": OPTIMIZER,
        "optimizer_momentum": MOMENTUM,
        "optimizer_weight_decay": WEIGHT_DECAY,
        "scheduler": SCHEDULER,
        "scheduler_modulo": SCHEDULER_MODULO,
        "scheduler_gamma": GAMMA,
    }
    # Time at which this checkpoint is written (not the start of the run).
    metadata["date"] = datetime.today().strftime('%Y%m%d_%H%M%S')

    metadata_path = os.path.join(dest, f"json_{epoch}.json")
    with open(metadata_path, "w") as handle:
        json.dump(metadata, handle)

    weights_path = os.path.join(dest, f"state_dict_{epoch}.pt")
    torch.save(siamese_network.state_dict(), weights_path)

In [None]:
criterion = TripletLoss(margin=MARGIN)

match OPTIMIZER:
    case "adam":
        opt = torch.optim.Adam(siamese_network.parameters(), lr=LEARNING_RATE)
    case "sgd":
        opt = torch.optim.SGD(
            # filter(lambda p: p.requires_grad, siamese_network.parameters()),
            siamese_network.parameters(),
            LEARNING_RATE,
            momentum=MOMENTUM,
            weight_decay=WEIGHT_DECAY
        )
        
match SCHEDULER:
    case "exponentiallr":
        sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.1)
    case "custom_cosine_annealing_without_restart":
        lambda_sched = lambda epoch: 0.5 * (1 + math.cos(math.pi * epoch / EPOCHS))
        sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda_sched)

In [None]:
# Training loop: one training pass and one validation pass per epoch, followed
# by a checkpoint and (every SCHEDULER_MODULO epochs) a scheduler step.
for epoch in range(EPOCHS):
    # Set train mode explicitly at the start of every epoch. If this cell is
    # interrupted during validation and run again, the network would otherwise
    # still be in eval mode from the previous execution.
    siamese_network.train()
    losses = []

    tqdm_data_loader = tqdm(data_loader, unit="batch")
    tqdm_data_loader.set_description(f"Epoch {epoch}/{EPOCHS}")
    # Each batch is a 6-tuple; only the three images are used here.
    for i, (anchor, positive, negative, _, _, _) in enumerate(tqdm_data_loader):
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        # The same network embeds all three images of the triplet.
        anchor_output = siamese_network(anchor)
        positive_output = siamese_network(positive)
        negative_output = siamese_network(negative)

        # In this training pass the Inception output is unwrapped via `.logits`;
        # the validation pass below uses the raw output.
        # NOTE(review): presumably the model returns a plain tensor in eval
        # mode - confirm against the model definition.
        if ARCHITECTURE == "inceptionv3":
            anchor_output = anchor_output.logits
            positive_output = positive_output.logits
            negative_output = negative_output.logits

        loss = criterion(anchor_output, positive_output, negative_output)
        # Read the scalar once; every .item() call copies the value to the host.
        loss_value = loss.item()

        # Log every 10th batch to keep the wandb history compact.
        if (i+1) % 10 == 0:
            wandb.log({
                "loss": loss_value,
                "learning_rate": sched.get_last_lr()[-1],
                "epoch": epoch,
                "batch": i,
            })

        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss_value)

        tqdm_data_loader.set_postfix({
            "Loss": loss_value
        })

    # Validation pass: eval mode and no gradient tracking.
    val_losses = []
    siamese_network.eval()
    with torch.no_grad():
        tqdm_data_loader_val = tqdm(data_loader_val, unit="batch")
        tqdm_data_loader_val.set_description(f"Epoch {epoch}/{EPOCHS}")
        for i, (anchor, positive, negative, _, _, _) in enumerate(tqdm_data_loader_val):
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            anchor_output = siamese_network(anchor)
            positive_output = siamese_network(positive)
            negative_output = siamese_network(negative)

            loss = criterion(anchor_output, positive_output, negative_output)

            val_losses.append(loss.item())
    # Leave the network in train mode, also after the final epoch.
    siamese_network.train()

    # Per-epoch summary: mean of the per-batch losses of both passes.
    wandb.log({
        "loss_avg": np.mean(losses),
        "val_loss_avg": np.mean(val_losses),
    })

    # Checkpoint after every epoch.
    save_model(epoch=epoch)

    # Step the scheduler only every SCHEDULER_MODULO epochs.
    if (epoch+1) % SCHEDULER_MODULO == 0: # epoch % 5 == 0:
        sched.step()

In [None]:
# Close the Weights & Biases run that was started earlier in this notebook.
wandb.finish()

### Modell speichern

In [17]:
# value_str = ARCHITECTURE.lower() + "_margin" + str(MARGIN) + "_learningrate" + str(LEARNING_RATE).split(".", 1)[1] + "_outputdim" + str(OUTPUT_DIM) + "_epochs" + str(EPOCHS) + "_samples" + str(IMAGE_NUM) + "_batchsize" + str(BATCH_SIZE) + "_date" + datetime.today().strftime('%Y%m%d_%H%M%S')

# d = {
#     "architecture": ARCHITECTURE,
#     "margin": MARGIN,
#     "learing rate": LEARNING_RATE,
#     "output dimension": OUTPUT_DIM,
#     "epochs": EPOCHS,
#     "samples": IMAGE_NUM,
#     "batchsize": BATCH_SIZE,
#     "dataset": "ImageNet",
#     "classes": CLASSES,
#     "loss": losses,
#     "date": datetime.today().strftime('%Y%m%d_%H%M%S'),
# }

# with open(r"C:\Users\mariu\Desktop\statedict_siamese_network_" + value_str + ".json", "w") as file:
#     json.dump(d, file)

# torch.save(siamese_network.state_dict(), r"C:\Users\mariu\Desktop\statedict_siamese_network_" + value_str + ".pt")