In [9]:
import os
import torch
from retrievalmodels import RetrievalModel
from modelfinetuning import fine_tune_with_identity
from compute_embedding_celebA import compute_embeddings_from_images

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
PARENT_DIRNAME = os.path.expanduser("~/image-processing-project/")
IMAGE_DIR = os.path.join(PARENT_DIRNAME, "data/img_align_celeba/")
STORAGE_DATA_DIRNAME = os.path.join(PARENT_DIRNAME, "fine_tuning/data_for_fine_tuning")
MODEL_DIR = os.path.join(PARENT_DIRNAME, "fine_tuning/models")

In [12]:
NUM_WORKERS = 4
LEARNING_RATE = 0.001

In [13]:
train_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "train_loader.pth"))
query_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "query_loader.pth"))
gallery_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "gallery_loader.pth"))

  train_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "train_loader.pth"))
  query_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "query_loader.pth"))
  gallery_loader = torch.load(os.path.join(STORAGE_DATA_DIRNAME, "gallery_loader.pth"))


# ResNet50

In [14]:
model = RetrievalModel(backbone="resnet50", embedding_dim=128).to(device)

In [15]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)

In [16]:
fine_tune_with_identity(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    num_epochs=10,
    device=device,
    num_threads=NUM_WORKERS
)

torch.save(model.state_dict(), os.path.join(MODEL_DIR, "resnet50_identity.pth"))

Epoch 1/10: 100%|██████████| 2334/2334 [20:10<00:00,  1.93it/s]


Epoch 1, Triplet Loss: 0.3921


Epoch 2/10: 100%|██████████| 2334/2334 [20:14<00:00,  1.92it/s]


Epoch 2, Triplet Loss: 0.2603


Epoch 3/10: 100%|██████████| 2334/2334 [20:18<00:00,  1.92it/s]


Epoch 3, Triplet Loss: 0.1965


Epoch 4/10: 100%|██████████| 2334/2334 [20:22<00:00,  1.91it/s]


Epoch 4, Triplet Loss: 0.1800


Epoch 5/10: 100%|██████████| 2334/2334 [20:19<00:00,  1.91it/s]


Epoch 5, Triplet Loss: 0.1496


Epoch 6/10: 100%|██████████| 2334/2334 [20:20<00:00,  1.91it/s]


Epoch 6, Triplet Loss: 0.1336


Epoch 7/10: 100%|██████████| 2334/2334 [20:20<00:00,  1.91it/s]


Epoch 7, Triplet Loss: 0.1203


Epoch 8/10: 100%|██████████| 2334/2334 [20:12<00:00,  1.92it/s]


Epoch 8, Triplet Loss: 0.1146


Epoch 9/10: 100%|██████████| 2334/2334 [20:07<00:00,  1.93it/s]


Epoch 9, Triplet Loss: 0.1065


Epoch 10/10: 100%|██████████| 2334/2334 [20:13<00:00,  1.92it/s]

Epoch 10, Triplet Loss: 0.1032





In [17]:
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "resnet50_identity.pth")))

  model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "resnet50_identity.pth")))


<All keys matched successfully>

In [18]:
full_embeddings, full_labels = compute_embeddings_from_images(
    model=model,
    device=device
)

torch.save(full_embeddings, os.path.join(STORAGE_DATA_DIRNAME, "full_embeddings_resnet.pth"))
torch.save(full_labels, os.path.join(STORAGE_DATA_DIRNAME, "full_labels_resnet.pth"))

# MobileNetV2

In [4]:
model = RetrievalModel(backbone="mobilenet_v2", embedding_dim=128).to(device)



In [23]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)

In [24]:
fine_tune_with_identity(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    num_epochs=23,
    device=device,
    num_threads=NUM_WORKERS
)

torch.save(model.state_dict(), os.path.join(MODEL_DIR, "mobilenet_v2_identity.pth"))

Epoch 1/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.62it/s]


Epoch 1, Triplet Loss: 0.4539


Epoch 2/23: 100%|██████████| 2334/2334 [08:27<00:00,  4.60it/s]


Epoch 2, Triplet Loss: 0.3837


Epoch 3/23: 100%|██████████| 2334/2334 [08:28<00:00,  4.59it/s]


Epoch 3, Triplet Loss: 0.3513


Epoch 4/23: 100%|██████████| 2334/2334 [08:26<00:00,  4.61it/s]


Epoch 4, Triplet Loss: 0.3369


Epoch 5/23: 100%|██████████| 2334/2334 [08:28<00:00,  4.59it/s]


Epoch 5, Triplet Loss: 0.3254


Epoch 6/23: 100%|██████████| 2334/2334 [08:28<00:00,  4.59it/s]


Epoch 6, Triplet Loss: 0.3144


Epoch 7/23: 100%|██████████| 2334/2334 [08:28<00:00,  4.59it/s]


Epoch 7, Triplet Loss: 0.3043


Epoch 8/23: 100%|██████████| 2334/2334 [08:26<00:00,  4.60it/s]


Epoch 8, Triplet Loss: 0.2968


Epoch 9/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.63it/s]


Epoch 9, Triplet Loss: 0.2919


Epoch 10/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.62it/s]


Epoch 10, Triplet Loss: 0.2869


Epoch 11/23: 100%|██████████| 2334/2334 [08:25<00:00,  4.62it/s]


Epoch 11, Triplet Loss: 0.2779


Epoch 12/23: 100%|██████████| 2334/2334 [08:22<00:00,  4.65it/s]


Epoch 12, Triplet Loss: 0.2746


Epoch 13/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.62it/s]


Epoch 13, Triplet Loss: 0.2686


Epoch 14/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.62it/s]


Epoch 14, Triplet Loss: 0.2650


Epoch 15/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.63it/s]


Epoch 15, Triplet Loss: 0.2590


Epoch 16/23: 100%|██████████| 2334/2334 [08:22<00:00,  4.64it/s]


Epoch 16, Triplet Loss: 0.2538


Epoch 17/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.62it/s]


Epoch 17, Triplet Loss: 0.2490


Epoch 18/23: 100%|██████████| 2334/2334 [08:24<00:00,  4.63it/s]


Epoch 18, Triplet Loss: 0.2372


Epoch 19/23: 100%|██████████| 2334/2334 [08:23<00:00,  4.63it/s]


Epoch 19, Triplet Loss: 0.2253


Epoch 20/23: 100%|██████████| 2334/2334 [08:20<00:00,  4.66it/s]


Epoch 20, Triplet Loss: 0.2046


Epoch 21/23: 100%|██████████| 2334/2334 [08:20<00:00,  4.66it/s]


Epoch 21, Triplet Loss: 0.1940


Epoch 22/23: 100%|██████████| 2334/2334 [08:22<00:00,  4.65it/s]


Epoch 22, Triplet Loss: 0.1870


Epoch 23/23: 100%|██████████| 2334/2334 [08:23<00:00,  4.64it/s]

Epoch 23, Triplet Loss: 0.1794





In [5]:
model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "mobilenet_v2_identity.pth")))

  model.load_state_dict(torch.load(os.path.join(MODEL_DIR, "mobilenet_v2_identity.pth")))


<All keys matched successfully>

In [6]:
full_embeddings, full_labels = compute_embeddings_from_images(
    model=model,
    device=device
)

torch.save(full_embeddings, os.path.join(STORAGE_DATA_DIRNAME, "full_embeddings_mobilenet.pth"))
torch.save(full_labels, os.path.join(STORAGE_DATA_DIRNAME, "full_labels_mobilenet.pth"))