## Settings

In [None]:
# Automatically reload edited project modules (e.g. `data`, `networks`) before code runs,
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random

import numpy as np
import torch

# Root folder containing the `classy_coconut/{train,val}` image splits.
data_dir = 'data'
# Label set passed to ImageDataset, in this fixed order
# (these appear to be the 80 COCO category names).
classes = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush',
)
n_epochs = 100   # epoch budget for fit_one_cycle in each phase (EarlyStoppingCallback may stop sooner)
emsize = 128     # embedding size passed to siamese_embedding_learner
# NOTE(review): batch_size1/batch_size2 are not referenced by any visible cell; the batch
# composition is set by BalancedBatchSampler(nc, ns). Presumably these are meant to equal
# nc1 * ns1 and nc2 * ns2 — confirm before relying on them.
batch_size1 = 16
batch_size2 = 8
margin = 1.      # margin passed to siamese_embedding_learner
load = None      # optional checkpoint name loaded before Phase 1; None skips loading
# BalancedBatchSampler arguments for Phase 1 (nc1, ns1) and Phase 2 (nc2, ns2).
# Presumably "classes per batch" and "samples per class" — TODO confirm in data.BalancedBatchSampler.
nc1 = 4
ns1 = 4
nc2 = 4
ns2 = 2

# Seed Python, NumPy and PyTorch RNGs so augmentation, sampling and initialisation
# are repeatable across Restart & Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## Training

In [None]:
import torch.utils.data as data
import os
from data import ImageDataset

from fastai.vision import *
from fastai.callbacks import *

from networks import siamese_embedding_learner
from data import BalancedBatchSampler, ImageEmbedList
from torch.utils.data.dataloader import default_collate

import wandb
from wandb.fastai import WandbCallback
# Start the W&B run and attach this notebook's hyperparameters to it, so runs in the
# "embedders-vision" project can be filtered and compared by configuration.
wandb.init(
    project="embedders-vision",
    config={
        "n_epochs": n_epochs,
        "emsize": emsize,
        "margin": margin,
        "load": load,
        "nc1": nc1,
        "ns1": ns1,
        "nc2": nc2,
        "ns2": ns2,
        "n_classes": len(classes),
    },
)

In [None]:
from torchvision import transforms
import torch.nn.functional as F
import torch

# Plain torch datasets for the two splits: PIL image -> tensor -> `Image` -> 512px crop/pad.
# No normalization is applied here; the fastai databunches below call .normalize(imagenet_stats).
# NOTE(review): train_dataset, val_dataset and n_classes are not used by any later visible
# cell (training goes through `lls`); they only feed the size report printed below.
tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        Image,
        partial(crop_pad, size=512, padding_mode="zeros"),
    ]
)

train_dataset = ImageDataset(f"{data_dir}/classy_coconut/train", classes, tfms=tfms)
val_dataset = ImageDataset(f"{data_dir}/classy_coconut/val", classes, tfms=tfms)
n_classes = len(classes)

print(f"Number of items in the training data set: {len(train_dataset)}")
print(f"Number of items in the validation data set: {len(val_dataset)}")

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# fastai data block: images under classy_coconut/, split by the "val" folder,
# labelled by their parent folder name.
lls = (ImageEmbedList.from_folder(f"{data_dir}/classy_coconut")
       .split_by_folder(valid="val")
       .label_from_folder()
)


def label_tensor(label_list):
    """Return the integer labels of a fastai ``LabelList`` as a 1-D ``torch.long`` tensor.

    Labels are read through ``label_list.y`` so only the label side is touched.
    Indexing ``label_list[i]`` (as this cell previously did) also fetches the input
    ``x`` — i.e. opens every image on disk — just to discard it.

    Args:
        label_list: a fastai ``LabelList`` such as ``lls.train`` or ``lls.valid``.

    Returns:
        Tensor of shape ``(len(label_list),)`` holding each item's label ``.data``.
    """
    return torch.tensor(
        [label_list.y[i].data for i in range(len(label_list))], dtype=torch.long
    )


# Per-item labels consumed by BalancedBatchSampler in both training phases.
labels_train = label_tensor(lls.train)
labels_val = label_tensor(lls.valid)

In [None]:
def pad_collate(batch, size=512, padding_mode="zeros"):
    """Collate ``(input, label)`` pairs after bringing each input to a common size.

    Every input is passed through ``crop_pad`` with the given ``size`` and
    ``padding_mode``. The pairs are then converted with ``to_data`` and stacked by
    PyTorch's ``default_collate``.

    Args:
        batch: sequence of ``(input, label)`` pairs produced by the dataset.
        size: target size forwarded to ``crop_pad``.
        padding_mode: padding mode forwarded to ``crop_pad``.

    Returns:
        Whatever ``default_collate`` builds from the converted batch.
    """
    resized = []
    for inp, lab in batch:
        resized.append((crop_pad(inp, size=size, padding_mode=padding_mode), lab))
    return default_collate(to_data(resized))

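### Phase 1

Train with `learner.freeze()`, using `BalancedBatchSampler(..., nc1, ns1)` batches for both splits, then save the weights as `embedder` for Phase 2.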

In [None]:
# Phase 1 data: augment `lls` with get_transforms(), draw batches with
# BalancedBatchSampler(nc1, ns1) on both splits, pad/crop inputs via pad_collate,
# and normalize with ImageNet statistics.
dbunch = lls.transform(get_transforms()).databunch(
    "siamese",
    bsampler=BalancedBatchSampler(labels_train, nc1, ns1),
    val_bsampler=BalancedBatchSampler(labels_val, nc1, ns1),
    collate_fn=pad_collate,
    device=device,
    num_workers=8,
).normalize(imagenet_stats)

In [None]:
# Callback factories handed to the learner (it instantiates each partial itself).
# They are reused unchanged for the Phase 2 learner.
callback_fns = [
    partial(CSVLogger, append=True),                                                      # per-epoch CSV log
    partial(SaveModelCallback, every="improvement", monitor="valid_loss"),                # checkpoint on improvement
    partial(EarlyStoppingCallback, monitor="valid_loss", min_delta=0.0005, patience=5),   # stop when valid_loss plateaus
    partial(WandbCallback, input_type='images'),                                          # W&B logging
]

# ResNet-50 siamese embedder with `emsize`-dim output and contrastive `margin`, in mixed precision.
learner = siamese_embedding_learner(
    dbunch,
    models.resnet50,
    emsize,
    margin,
    callback_fns=callback_fns,
).to_fp16()

In [None]:
# Optionally start Phase 1 from a previously saved checkpoint. `load` is set in the
# Settings cell; None skips this step.
if load is not None:
    learner.load(load)

In [None]:
# Phase 1: train after learner.freeze() (in fastai this leaves only the last layer group
# trainable). Save the result as "embedder" for Phase 2, then destroy this learner,
# since Phase 2 builds a new one.
learner.freeze()
learner.fit_one_cycle(cyc_len=n_epochs, max_lr=1e-3)
learner.save("embedder")
learner.destroy()

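### Phase 2

Rebuild the databunch with `BalancedBatchSampler(..., nc2, ns2)`, reload the Phase 1 `embedder` weights, and continue training with `learner.freeze_to(1)` and the per-group learning rates `[1e-5, 1e-4, 1e-3]`.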

In [None]:
# Phase 2 data: same pipeline as Phase 1 but with BalancedBatchSampler(nc2, ns2)
# on both splits.
dbunch = lls.transform(get_transforms()).databunch(
    "siamese",
    bsampler=BalancedBatchSampler(labels_train, nc2, ns2),
    val_bsampler=BalancedBatchSampler(labels_val, nc2, ns2),
    collate_fn=pad_collate,
    device=device,
    num_workers=8,
).normalize(imagenet_stats)

In [None]:
# Fresh mixed-precision learner over the Phase 2 databunch, initialised from the
# "embedder" weights saved at the end of Phase 1.
learner = siamese_embedding_learner(
    dbunch,
    models.resnet50,
    emsize,
    margin,
    callback_fns=callback_fns,
).to_fp16()
learner.load("embedder")

In [None]:
# Phase 2: freeze_to(1) keeps layer group 0 frozen and trains the remaining groups.
# The list gives one max learning rate per layer group (presumably three groups — confirm).
# The final weights overwrite "embedder", and the learner is serialized with export().
learner.freeze_to(1)
learner.fit_one_cycle(cyc_len=n_epochs, max_lr=[1e-5, 1e-4, 1e-3])
learner.save("embedder")
learner.export()

## Evaluation

In [None]:
import json

# Evaluate the final Phase 2 learner on the validation set and write its first returned
# value (stored under "loss") to metrics.json.
metrics = learner.validate()
metrics_dict = {"siamese": {}}
metrics_dict["siamese"]["loss"] = float(metrics[0])

with open("metrics.json", "w") as metrics_file:
    json.dump(metrics_dict, metrics_file)