In [27]:
import os
import pickle
from PIL import Image
from tqdm import tqdm

import pandas as pd
import numpy as np
import torch
from torch import nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cross_entropy, softmax
from torch.optim import Adam
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping
from sklearn.model_selection import train_test_split


base = "../"  # where the data is located
img_dir = base + "food/"  # where you unzipped food.zip
train_path = base + "train_triplets.txt"
test_path = base + "test_triplets.txt"


def get_split():
    triplets = np.loadtxt(train_path, delimiter=" ").astype(int)
    n_samples = len(triplets)
    triplets = [[a, b, c] if i == 1 else [a, c, b] for a, b, c in triplets for i in (1, 0)]
    labels = torch.tensor(np.resize([1, 0], 2*n_samples))

    train_triplets, val_triplets, train_labels, val_labels = train_test_split(
        triplets, labels, test_size=0.1, shuffle=True
    )

    return train_triplets, train_labels, val_triplets, val_labels
    

class TripletsDataset(Dataset):
    def __init__(self, triplets, features, labels=None):
        self.triplets = triplets
        self.features = features
        if labels is not None:
            self.train = True
            self.labels = labels
        else:
            self.train = False

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        a, b, c = self.features[triplet]
        phi = torch.cat((a, b, c), dim=0)
        if self.train:
            label = self.labels[idx]
            return phi, label

        return phi


# this defines our network.
# we use a pytorch lightning LightningModule instead of a nn.Module
# for its convenient features. Hence why we implement many of the 
# methods below, they are reserved by lightning and used by the 
# trainer
class SimilarityNet(pl.LightningModule):
    def __init__(self, features_dim, lr=1e-3, batch_size=8):
        super().__init__()

        self.batch_size = batch_size
        self.learning_rate = lr

        self.classifier = nn.Sequential(
            nn.Linear(in_features=3*features_dim, out_features=2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 2)
        )

        self.loss = nn.CrossEntropyLoss()
        self.val_loss = cross_entropy

    def forward(self, x):
        return self.classifier(x)

    def training_step(self, batch):
        x, y = batch
        logit = self(x)
        loss = self.loss(logit, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, idx):
        x, y = batch
        logit = self(x)
        loss = self.val_loss(logit, y)
        self.log("val_loss", loss)

        probs = softmax(logit, dim=1)
        pred = probs.argmax(dim=1)
        acc = (pred != y).float().mean()
        self.log("val_acc", acc)

        return {"val_loss": loss, "val_acc": acc}

    def predict_step(self, batch, idx):
        logit = self.classifier(batch)
        probs = softmax(logit, dim=1)
        pred = probs.argmax(dim=1)
        return pred

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack(
            [x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack(
            [x["val_acc"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)
        self.log("val_acc", avg_acc)



# first we compute the embedded images with a pretrained resnet
backbone = models.resnet101(pretrained=True)
features_dim = list(backbone.children())[-1].in_features

if os.path.exists("r101_features.pkl"): 
    # fetch precopmuted features from earlier run
    with open("r101_features.pkl", "rb") as f:
        features = pickle.load(f)
else:
    # compute features
    feature_map = nn.Sequential(*list(backbone.children())[:-1])  # create embedding map by omitting final classification layer
    feature_map.cuda()

    tfms = transforms.Compose(
        [
            transforms.Resize((242, 354)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )  # preprocessing function

    n_imgs = 10_000
    features = torch.empty((n_imgs, features_dim))
    with torch.no_grad():
        for i in tqdm(range(n_imgs)):
            path = img_dir + str(i).rjust(5, "0") + ".jpg"
            img = tfms(Image.open(path)).unsqueeze(0).cuda()
            phi = feature_map(img)  # forward pass
            features[i] = phi.squeeze().cpu().float()

    # save features
    with open("r101_features.pkl", "wb") as f:
        pickle.dump(features, f) 


# training our classifier
learning_rate = 1e-4
batch_size = 2048
num_workers = 32

train_triplets, train_labels, val_triplets, val_labels = get_split()
train_dataset = TripletsDataset(
    triplets=train_triplets,
    features=features,
    labels=train_labels
    )
val_dataset = TripletsDataset(
    triplets=val_triplets,
    features=features,
    labels=val_labels
    )

train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=True,
    )
val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=True,
    )


model = SimilarityNet(
    features_dim=features_dim,
    lr=learning_rate,
    batch_size=batch_size,
)

bar = TQDMProgressBar(refresh_rate=1)
early_stop = EarlyStopping(
    monitor="val_acc", mode="max", min_delta=0.002, patience=10, verbose=True
)

trainer = Trainer(
    # fast_dev_run=True,
    accelerator="gpu", 
    devices=[0], 
    # auto_select_gpus=True,
    min_epochs=1,
    max_epochs=250,
    callbacks=[bar, early_stop],
    auto_lr_find=True,
    auto_scale_batch_size=False
)

# trainer.tune(model, train_loader, val_loader)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# predict
test_triplets = np.loadtxt(test_path, delimiter=" ").astype(int)
test_dataset = TripletsDataset(test_triplets, features)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    persistent_workers=True,
    shuffle=False
)

predictions = trainer.predict(model,  test_loader)
predictions = torch.cat(predictions, dim=0).tolist()

df_pred = pd.DataFrame(predictions)
df_pred.to_csv("../predictions/r101_classifier.txt", index=False, header=None)

100%|██████████| 10000/10000 [06:35<00:00, 25.30it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name       | Type             | Params
------------------------------------------------
0 | classifier | Sequential       | 14.7 M
1 | loss       | CrossEntropyLoss | 0     
------------------------------------------------
14.7 M    Trainable params
0         Non-trainable params
14.7 M    Total params
58.741    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_acc improved. New best score: 0.501


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/u/hehlif/miniconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>Traceback (most recent call last):

Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/u/he

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/u/hehlif/miniconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>

  File "/u/hehlif/m

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/u/hehlif/miniconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shut

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/u/hehlif/miniconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb419c4e3a0>
Traceback (most recent call last):
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/u/hehlif/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shut

Validation: 0it [00:00, ?it/s]

Monitored metric val_acc did not improve in the last 10 records. Best score: 0.501. Signaling Trainer to stop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 53it [00:00, ?it/s]