In [2]:
from functools import partial
import os
from PIL import Image
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
from torch import nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.linalg import norm
from torch.optim import Adam, SGD
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping
from sklearn.model_selection import train_test_split


# Dataset locations; every path is derived from the single root below.
base = "./"  # root directory in which the data is located
img_dir = f"{base}food"  # directory created by unzipping food.zip
train_path = f"{base}train_triplets.txt"
test_path = f"{base}test_triplets.txt"

# 1 when CUDA reports at least one device, otherwise 0.
AVAIL_GPUS = min(torch.cuda.device_count(), 1)


def get_split():
    """Load the training triplets and split them into train/validation sets.

    The space-delimited file at ``train_path`` is read with ``np.loadtxt``,
    cast to ``int`` and passed to ``train_test_split`` with shuffling
    enabled and ``test_size=0.1``.

    Returns:
        The ``(train_triplets, val_triplets)`` pair produced by
        ``train_test_split``.
    """
    all_triplets = np.loadtxt(train_path, delimiter=" ")
    all_triplets = all_triplets.astype(int)

    train_part, val_part = train_test_split(
        all_triplets, shuffle=True, test_size=0.1
    )
    return train_part, val_part
    

class TripletsDataset(Dataset):
    """Dataset that yields the three images referenced by each triplet.

    Args:
        triplets: Indexable, sized collection of triplets; each triplet holds
            exactly three image ids.
        img_dir: Directory containing the images. Each file is named after
            its id, left-padded with zeros to five characters, plus ``.jpg``
            (id ``42`` -> ``00042.jpg``).
        transform: Optional callable applied to every loaded image.
    """

    def __init__(self, triplets, img_dir, transform=None):
        self.triplets = triplets
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of triplets."""
        return len(self.triplets)

    def _load_image(self, image_id):
        """Load one image by id as RGB and apply the transform, if any."""
        path = os.path.join(self.img_dir, f"{image_id}".rjust(5, "0") + ".jpg")
        # Open inside a context manager so the file handle is released right
        # away, and force RGB so that grayscale/CMYK/RGBA files still produce
        # the three channels a 3-channel transform pipeline expects.
        with Image.open(path) as image:
            image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return the ``(img_a, img_b, img_c)`` images of triplet ``idx``.

        Raises:
            ValueError: If the triplet does not contain exactly three ids.
        """
        img_a, img_b, img_c = (
            self._load_image(image_id) for image_id in self.triplets[idx]
        )
        return img_a, img_b, img_c


# This defines our network.
# We use a PyTorch Lightning LightningModule instead of a plain nn.Module
# for its convenient features. That is why we implement many of the
# methods below: their names are reserved by Lightning and they are
# called by the Trainer.
class SimilarityNet(pl.LightningModule):
    """Triplet embedding network on top of a frozen backbone.

    Every child module of ``backbone`` except the last one is reused as a
    frozen feature extractor; a single linear layer maps the flattened
    features to a 1024-dimensional embedding. Training minimises a triplet
    margin loss (margin 5.0); validation reports a zero-margin triplet loss
    and the fraction of triplets whose first image is closer to the second
    image than to the third.

    Args:
        backbone: Module whose children, minus the last, form the feature
            extractor. The last child must expose ``in_features``.
        lr: Learning rate passed to Adam.
        batch_size: Batch size, stored on the module as ``batch_size``.
        train_dataset: Optional training dataset to keep on the module.
        val_dataset: Optional validation dataset to keep on the module.
    """

    def __init__(self, backbone, lr=1e-3, batch_size=8,
                 train_dataset=None, val_dataset=None):
        super().__init__()

        # Passed in explicitly rather than read from module globals, so the
        # class can be constructed outside the training script.
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.learning_rate = lr

        backbone_layers = list(backbone.children())
        # Backbone without its final layer, used as a fixed feature extractor.
        self.features = nn.Sequential(*backbone_layers[:-1])
        self.features.requires_grad_(False)
        embedding_dim = backbone_layers[-1].in_features
        self.embedding = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=1024)
        )
        self.loss = nn.TripletMarginLoss(margin=5.0)
        self.val_loss = partial(F.triplet_margin_loss, margin=0)

    def train(self, mode=True):
        """Set the training mode while keeping the frozen backbone in eval mode.

        ``requires_grad_(False)`` only freezes parameters. Layers such as
        BatchNorm would still update their running statistics (and normalise
        with batch statistics) in training mode, so the feature extractor
        would behave differently during training and evaluation.
        """
        super().train(mode)
        self.features.eval()
        return self

    def _embed(self, img):
        """Map a batch of images to a batch of embedding vectors."""
        phi = self.features(img)
        phi = phi.view(phi.size(0), -1)  # flatten everything but the batch dim
        return self.embedding(phi)

    def forward(self, img_a, img_b, img_c):
        """Return the embeddings of the three image batches, in order."""
        return self._embed(img_a), self._embed(img_b), self._embed(img_c)

    def training_step(self, batch):
        """Compute and log the triplet margin loss for one training batch."""
        img_a, img_b, img_c = batch
        embedded_a, embedded_b, embedded_c = self(img_a, img_b, img_c)
        loss = self.loss(embedded_a, embedded_b, embedded_c)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, *args):
        """Compute and log validation loss and accuracy for one batch.

        Returns:
            Dict with ``"val_loss"`` and ``"val_acc"`` scalar tensors.
        """
        img_a, img_b, img_c = batch
        embedded_a, embedded_b, embedded_c = self(img_a, img_b, img_c)
        loss = self.val_loss(embedded_a, embedded_b, embedded_c)
        self.log("val_loss", loss)

        d_ab = norm(embedded_a - embedded_b, dim=-1)
        d_ac = norm(embedded_a - embedded_c, dim=-1)
        # Average over the samples actually present in the batch; dividing by
        # the configured batch size is wrong for a smaller (e.g. last) batch.
        acc = (d_ab < d_ac).float().mean()
        self.log("val_acc", acc)

        return {"val_loss": loss, "val_acc": acc}

    def predict_step(self, batch, *args):
        """Return the three embedding batches for one prediction batch."""
        img_a, img_b, img_c = batch
        embedded_a, embedded_b, embedded_c = self(img_a, img_b, img_c)
        return embedded_a, embedded_b, embedded_c

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured learning rate."""
        return Adam(self.parameters(), lr=self.learning_rate)

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation loss and accuracy."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log("val_acc", avg_acc)


# Script entry point: train the similarity network on the training triplets,
# save its weights, then write one 0/1 prediction per test triplet.
if __name__ == "__main__":
    learning_rate = 1e-4
    batch_size = 128
    num_workers = 16
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((242, 354)),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )  # could also try augmentation

    train_triplets, val_triplets = get_split()
    train_dataset = TripletsDataset(train_triplets, img_dir, transform)
    val_dataset = TripletsDataset(val_triplets, img_dir, transform)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,  # present the triplets in a new order every epoch
        num_workers=num_workers,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=True,
        drop_last=True,
    )

    backbone = models.resnet50(pretrained=True)

    model = SimilarityNet(
        backbone=backbone,
        lr=learning_rate,
        batch_size=batch_size,
    )

    bar = TQDMProgressBar(refresh_rate=1)
    early_stop = EarlyStopping(
        monitor="val_acc", mode="max", min_delta=0.005, patience=1, verbose=True
    )

    trainer = Trainer(
        # fast_dev_run=True,
        accelerator="gpu",
        devices=AVAIL_GPUS,
        auto_select_gpus=True,
        min_epochs=1,
        max_epochs=50,
        callbacks=[bar, early_stop],
        auto_lr_find=True,
        auto_scale_batch_size=False,
    )

    # NOTE: auto_lr_find only takes effect when trainer.tune(model) is run.
    # trainer.tune(model)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # save weights
    torch.save(model.state_dict(), 'resnet50_nonlinear.pth')

    # predict: the test triplets come from test_path, not the training file
    test_triplets = np.loadtxt(test_path, delimiter=" ").astype(int)
    test_dataset = TripletsDataset(test_triplets, img_dir, transform)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        # Run over every batch so that each test triplet gets a prediction.
        for img_a, img_b, img_c in tqdm(test_loader):
            embedded_a, embedded_b, embedded_c = model(
                img_a.to(device), img_b.to(device), img_c.to(device)
            )
            # Keep the batch dimension (no squeeze) so that a final batch of
            # size one still yields a list rather than a 0-d tensor.
            d_ab = norm(embedded_a - embedded_b, dim=-1)
            d_ac = norm(embedded_a - embedded_c, dim=-1)
            # 1 when image a is closer to b than to c, else 0.
            predictions.extend((d_ab < d_ac).int().tolist())

    df_pred = pd.DataFrame(predictions)
    df_pred.to_csv("resnet50_nonlinear.txt", index=False, header=False)

Auto select gpus: [0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type              | Params
------------------------------------------------
0 | features  | Sequential        | 23.5 M
1 | embedding | Sequential        | 29.3 M
2 | loss      | TripletMarginLoss | 0     
------------------------------------------------
52.8 M    Trainable params
0         Non-trainable params
52.8 M    Total params
211.116   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

RuntimeError: CUDA out of memory. Tried to allocate 680.00 MiB (GPU 0; 10.92 GiB total capacity; 8.46 GiB already allocated; 201.38 MiB free; 8.65 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF