In [1]:
import os
import sys
import pickle
from functools import partial
from PIL import Image
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
from torch import nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.linalg import norm
from torch.optim import Adam, SGD
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping
from sklearn.model_selection import train_test_split


base = "../"  # where the data is located
img_dir = base + "food"  # where you unzipped food.zip
train_path = base + "train_triplets.txt"
test_path = base + "test_triplets.txt"


def get_split(val_ratio):
    triplets = np.loadtxt(train_path, delimiter=" ").astype(int)
    train_triplets, val_triplets = train_test_split(
        triplets, test_size=val_ratio, random_state=489, shuffle=True
    )
    return train_triplets, val_triplets


def get_features(name):
    pkl_path = f"{name}_features.pkl"

    if name == "ResNet18":
        backbone = models.resnet18(pretrained=True)
    elif name == "ResNet34":
        backbone = models.resnet34(pretrained=True)
    elif name == "ResNet50":
        backbone = models.resnet50(pretrained=True)
    elif name == "ResNet101":
        backbone = models.resnet101(pretrained=True)
    elif name == "ResNet152":
        backbone = models.resnet152(pretrained=True)
    elif name == "ViT_b_16":
        backbone = models.vit_b_16(pretrained=True)
    elif name == "ViT_b_32":
        backbone = models.vit_b_32(pretrained=True)
    else:
        sys.exit("Error: This model is not implemented.")

    if name.startswith("ResNet"):
        features_dim = list(backbone.children())[-1].in_features
    elif name.startswith("ViT"):
        features_dim = list(backbone.heads.children())[0].in_features

    if os.path.exists(pkl_path):
        # fetch precomputed features from earlier run
        with open(pkl_path, "rb") as f:
            features = pickle.load(f)
    else:
        # compute features
        if name.startswith("ResNet"):
            feature_map = nn.Sequential(*list(backbone.children())[:-1])
            features_dim = list(backbone.children())[-1].in_features
            tfms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((242, 354)),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )  # preprocessing function
        elif name.startswith("ViT"):
            feature_map = nn.Sequential(*list(backbone.children())[:-1])[1]
            features_dim = list(backbone.heads.children())[0].in_features
            tfms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )  # preprocessing function

        feature_map.cuda().eval()

        n_imgs = 10_000
        features = torch.empty((n_imgs, features_dim))

        with torch.no_grad():
            for i in tqdm(range(n_imgs)):
                path = os.path.join(img_dir, f"{str(i).rjust(5, '0')}.jpg")
                if name.startswith("ResNet"):
                    img = tfms(Image.open(path)).unsqueeze(0).cuda()
                    phi = feature_map(img)  # forward pass
                elif name.startswith("ViT"):
                    img = tfms(Image.open(path)).unsqueeze(0).cuda()
                    x = backbone._process_input(img)
                    batch_class_token = backbone.class_token.expand(x.shape[0], -1, -1)
                    x = torch.cat([batch_class_token, x], dim=1)
                    phi = feature_map(x)[:, 0]
                features[i] = phi.squeeze()

        features = features.cpu().float()

        # save features
        with open(pkl_path, "wb") as f:
            pickle.dump(features, f)

    return features, features_dim


class TripletsDataset(Dataset):
    def __init__(self, triplets, features):
        self.triplets = triplets
        self.features = features

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        a, b, c = self.features[triplet]
        return a, b, c


# this defines our network.
# we use a pytorch lightning LightningModule instead of a nn.Module
# for its convenient features. Hence why we implement many of the
# methods below, they are reserved by lightning and used by the
# trainer
class SimilarityNet(pl.LightningModule):
    def __init__(
        self,
        features_dim,
        margin=5.0,
        embedding_dim=1024,
        lr=1e-3,
        momentum=0.9,
        nesterov=True,
        weight_decay=1e-3,
        batch_size=8,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.batch_size = batch_size
        self.learning_rate = lr
        self.momentum = momentum
        self.nesterov = nesterov
        self.weight_decay = weight_decay

        self.embedding = nn.Sequential(
            nn.Linear(in_features=features_dim, out_features=2048),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=2048, out_features=2048),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=2048, out_features=embedding_dim),
        )
        self.loss = nn.TripletMarginLoss(margin=margin)
        self.val_loss = partial(F.triplet_margin_loss, margin=0)

    def forward(self, a, b, c):
        embedded_a = self.embedding(a)
        embedded_b = self.embedding(b)
        embedded_c = self.embedding(c)

        return embedded_a, embedded_b, embedded_c

    def training_step(self, batch):
        a, b, c = batch
        embedded_a, embedded_b, embedded_c = self(a, b, c)
        loss = self.loss(embedded_a, embedded_b, embedded_c)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, idx):
        a, b, c = batch
        embedded_a, embedded_b, embedded_c = self(a, b, c)
        loss = self.val_loss(embedded_a, embedded_b, embedded_c)
        self.log("val_loss", loss)

        d_ab = norm(embedded_a - embedded_b, axis=-1).squeeze()
        d_ac = norm(embedded_a - embedded_c, axis=-1).squeeze()
        acc = (d_ab < d_ac).float().mean()
        self.log("val_acc", acc)

        return {"val_loss": loss, "val_acc": acc}

    def predict_step(self, batch, idx):
        a, b, c = batch
        embedded_a, embedded_b, embedded_c = self(a, b, c)
        d_ab = norm(embedded_a - embedded_b, axis=-1).squeeze()
        d_ac = norm(embedded_a - embedded_c, axis=-1).squeeze()
        pred = (d_ab <= d_ac).int()
        return pred

    def configure_optimizers(self):
        return SGD(
            self.parameters(),
            lr=self.learning_rate,
            momentum=self.momentum,
            nesterov=self.nesterov,
            weight_decay=self.weight_decay,
        )

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log("val_acc", avg_acc)


# first we compute the embedded images with a pretrained network

backbone = "ResNet152"  # choose one of ResNet[18, 34, 50, 101, 152], ViT_[b_16, b_32]
features, features_dim = get_features(backbone)

# hyperparameters
embedding_dim = 4096
learning_rate = 1e-3
momentum = 0.9
nesterov = True
weight_decay = 1e-3
margin = 5.0
val_ratio = 0.2
batch_size = 4096
num_workers = 32

train_triplets, val_triplets = get_split(val_ratio)
train_dataset = TripletsDataset(train_triplets, features)
val_dataset = TripletsDataset(val_triplets, features)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
)


model = SimilarityNet(
    features_dim=features_dim,
    margin=margin,
    lr=learning_rate,
    momentum=momentum,
    nesterov=nesterov,
    weight_decay=weight_decay,
    batch_size=batch_size,
)

bar = TQDMProgressBar(refresh_rate=1)
early_stop = EarlyStopping(
    monitor="val_acc", mode="max", min_delta=0.002, patience=30, verbose=True
)

trainer = Trainer(
    # fast_dev_run=True, # uncomment to debug
    accelerator="gpu",
    devices=torch.cuda.device_count(),
    auto_select_gpus=True,
    min_epochs=1,
    max_epochs=1000,
    callbacks=[bar, early_stop],
    auto_lr_find=True,
    auto_scale_batch_size=False,
)

# trainer.tune(model)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# predict
test_triplets = np.loadtxt(test_path, delimiter=" ").astype(int)
test_dataset = TripletsDataset(test_triplets, features)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    shuffle=False,
)

predictions = trainer.predict(model, test_loader)
predictions = torch.cat(predictions).tolist()

df_pred = pd.DataFrame(predictions)
sub_path = f"{backbone}.txt"
df_pred.to_csv(f"../predictions/{sub_path}", index=False, header=None)


Auto select gpus: [0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | embedding | Sequential        | 10.5 M
1 | loss      | TripletMarginLoss | 0     
------------------------------------------------
10.5 M    Trainable params
0         Non-trainable params
10.5 M    Total params
41.964    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_acc improved. New best score: 0.614


Validation: 0it [00:00, ?it/s]

Metric val_acc improved by 0.005 >= min_delta = 0.002. New best score: 0.620


OSError: [Errno 122] Disk quota exceeded