In [None]:
# Automatically re-import edited local modules (e.g. datasets, networks, losses imported below)
# before each cell executes, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# ---- Experiment configuration: everything a reader may want to tune ----

# Root folder; the train/val images are read from {data_dir}/classy_coconut/{train,val}.
data_dir = "data"

# The 80 category names (COCO object classes) used as the label set.
classes = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
)

n_epochs = 50     # length of the one-cycle training schedule
emsize = 128      # embedding size handed to EmbeddingNetPretrained
batch_size = 64   # DataBunch batch size
margin = 1.0      # margin handed to ContrastiveLoss
load = None       # name of saved weights to restore before training; None skips loading

In [None]:
import torch.utils.data as data
import os
from datasets import ImageDataset, pad_to_size

from fastai.vision import *
from fastai.callbacks import *

In [None]:
from torchvision import transforms
import torch.nn.functional as F
import torch

def _pad_to_512(img):
    """Pad an image tensor to a fixed 512x512 size via the project's pad_to_size helper."""
    return pad_to_size(img, (512, 512))


# Per-image preprocessing: convert to tensor, standardize with the ImageNet
# mean/std, then pad to a common 512x512 size.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*imagenet_stats),
    _pad_to_512,
])

train_dataset = ImageDataset(f"{data_dir}/classy_coconut/train", classes, tfms=tfms)
val_dataset = ImageDataset(f"{data_dir}/classy_coconut/val", classes, tfms=tfms)
n_classes = len(classes)

# Report split sizes (train first, then val).
for split_ds in (train_dataset, val_dataset):
    print(len(split_ds))

In [None]:
from datasets import SiameseImage

# Wrap each split in SiameseImage (presumably yielding image pairs for the contrastive loss).
# NOTE(review): the positional bool looks like a train/eval flag (True for train, False for
# val) -- confirm against datasets.SiameseImage and consider passing it by keyword.
siamese_train_dataset = SiameseImage(train_dataset, True)
siamese_val_dataset = SiameseImage(val_dataset, False)

from networks import EmbeddingNet, SiameseNet, EmbeddingNetPretrained
from losses import ContrastiveLoss

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Embedding network built on a ResNet-50 constructor with `emsize` outputs,
# wrapped by SiameseNet and trained with a contrastive loss using `margin`.
embedding_net = EmbeddingNetPretrained(models.resnet50, emsize=emsize)
model = SiameseNet(embedding_net)
loss_fn = ContrastiveLoss(margin)

In [None]:
from datasets import SiameseImageList
# Build the fastai item lists from the train/val siamese datasets (train first, valid second).
# NOTE(review): the return value is used below with .transform()/.databunch(), so it presumably
# behaves like a fastai LabelLists -- confirm in datasets.SiameseImageList.
lls = SiameseImageList.from_datasets(siamese_train_dataset, siamese_val_dataset)

In [None]:
# Build the fastai DataBunch with fastai's default augmentations.
#
# No `.normalize(imagenet_stats)` here: the torchvision `tfms` pipeline handed to
# ImageDataset already applies `transforms.Normalize(*imagenet_stats)` to every
# image (assuming ImageDataset applies its tfms on load -- confirm in datasets.py).
# Normalizing again at batch level would standardize the pixels twice, feeding the
# pretrained backbone inputs far from the ImageNet statistics it expects.
dbunch = (lls
          .transform(get_transforms())
          .databunch(bs=batch_size, device=device, num_workers=8)
         )

In [None]:
# Training callbacks (instantiated by the Learner):
#  - CSVLogger: log per-epoch results to a CSV file.
#  - SaveModelCallback: checkpoint whenever valid_loss improves.
#  - EarlyStoppingCallback: stop once valid_loss fails to improve by at least
#    min_delta for `patience` consecutive epochs.
save_best = partial(SaveModelCallback, every="improvement", monitor="valid_loss")
stop_early = partial(EarlyStoppingCallback, monitor="valid_loss", min_delta=0.0005, patience=5)
callback_fns = [CSVLogger, save_best, stop_early]

learner = Learner(dbunch, model, loss_func=loss_fn, callback_fns=callback_fns)

In [None]:
# Optionally resume from previously saved weights named by the `load` setting.
resume_from = load
if resume_from is not None:
    learner.load(resume_from)

In [None]:
# Train with the one-cycle policy for `n_epochs` epochs, peaking at `peak_lr`,
# then save the learner's current weights under the name "siamese".
# NOTE(review): SaveModelCallback(every="improvement") may reload the best
# checkpoint at train end depending on the fastai version -- verify which
# weights end up saved here.
peak_lr = 1e-3
learner.fit_one_cycle(n_epochs, max_lr=peak_lr)
learner.save("siamese")

In [None]:
import json

# Evaluate on the validation set; the first returned entry is the validation loss.
metrics = learner.validate()
valid_loss = float(metrics[0])
metrics_dict = {"siamese": {"loss": valid_loss}}

# Persist the result for downstream tooling.
with open("metrics.json", "w") as metrics_file:
    json.dump(metrics_dict, metrics_file)