<a href="https://colab.research.google.com/github/ericwolter/series-photo-selection/blob/main/sps_resnet50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; a later cell copies the Kaggle API
# token (kaggle.json) from MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install the Kaggle command-line client used below to download the dataset.
!pip install -qqq kaggle

In [3]:
!mkdir ~/.kaggle

In [4]:
!cp /content/drive/MyDrive/kaggle.json ~/.kaggle/kaggle.json

In [5]:
# Download the ericwolter/triage dataset archive (triage.zip, about 12.7 GB
# according to the recorded output) into the current working directory.
!kaggle datasets download ericwolter/triage

Downloading triage.zip to /content
100% 12.7G/12.7G [01:47<00:00, 108MB/s]
100% 12.7G/12.7G [01:48<00:00, 126MB/s]


In [6]:
# Extract the archive into ./triage. -q keeps the output quiet and -n never
# overwrites files that already exist, so re-running the cell is harmless.
!unzip -qn -d triage triage.zip

In [7]:
!pip install -qqq wandb pytorch-lightning torchmetrics

[K     |████████████████████████████████| 1.9 MB 5.5 MB/s 
[K     |████████████████████████████████| 800 kB 46.8 MB/s 
[K     |████████████████████████████████| 512 kB 51.8 MB/s 
[K     |████████████████████████████████| 182 kB 43.5 MB/s 
[K     |████████████████████████████████| 174 kB 55.4 MB/s 
[K     |████████████████████████████████| 62 kB 984 kB/s 
[K     |████████████████████████████████| 173 kB 63.4 MB/s 
[K     |████████████████████████████████| 168 kB 56.6 MB/s 
[K     |████████████████████████████████| 168 kB 52.7 MB/s 
[K     |████████████████████████████████| 166 kB 61.6 MB/s 
[K     |████████████████████████████████| 166 kB 59.4 MB/s 
[K     |████████████████████████████████| 162 kB 58.7 MB/s 
[K     |████████████████████████████████| 162 kB 68.2 MB/s 
[K     |████████████████████████████████| 158 kB 56.0 MB/s 
[K     |████████████████████████████████| 157 kB 57.9 MB/s 
[K     |████████████████████████████████| 157 kB 54.9 MB/s 
[K     |██████████████████

In [13]:
import os
import collections

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torchmetrics

from PIL import Image
import wandb

from pytorch_lightning.loggers import WandbLogger

from tqdm import tqdm

import functools

def debug(func):
    """Decorator that echoes every call as ``name({arg: value, ...})`` on stdout.

    Positional values are paired with the wrapped function's declared
    positional parameter names, keyword arguments are merged on top, and the
    wrapped function's return value is passed through untouched.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The leading co_argcount entries of co_varnames are the positional
        # parameter names, in declaration order.
        positional_names = func.__code__.co_varnames[:func.__code__.co_argcount]
        # zip() stops at the shorter sequence, so only values that line up
        # with a declared name are reported.
        call_args = {name: value for name, value in zip(positional_names, args)}
        call_args.update(kwargs)
        print(f"{func.__name__}({call_args})")
        return func(*args, **kwargs)
    return wrapper

# Fix the global random seed so runs are repeatable (the recorded output shows
# "Global seed set to 42").
pl.seed_everything(42)
# Authenticate with Weights & Biases before any logger or artifact call.
wandb.login()

class SPSSiamese(pl.LightningModule):
  """Siamese two-class classifier for image pairs on a frozen torchvision backbone.

  Both images of a pair go through the same pretrained backbone (all of its
  parameters have ``requires_grad=False`` and its last child module is
  dropped). The two flattened feature vectors are concatenated and passed
  through a small MLP that outputs log-probabilities over two classes.

  Args:
    backbone: name of a model constructor in ``torchvision.models``. The model
      must expose an ``fc`` attribute (as the ResNet family does), because its
      ``in_features`` sizes the first linear layer of the head.
    lr: learning rate handed to Adam.

  Raises:
    ValueError: if ``backbone`` is not a name in ``torchvision.models``.
  """

  def __init__(self, backbone='resnet18', lr=1e-4):
    super().__init__()

    if backbone not in models.__dict__:
      raise ValueError(f'No model named {backbone} exists in torchvision.models')

    pretrained = models.__dict__[backbone](pretrained=True)
    # Freeze the feature extractor; only the `combine` head receives gradients.
    # NOTE(review): requires_grad=False only stops gradient updates. If the
    # backbone contains BatchNorm layers they still update their running
    # statistics whenever the module is in train mode -- confirm whether the
    # backbone should additionally be kept in eval().
    for param in pretrained.parameters():
      param.requires_grad = False

    num_ftrs = pretrained.fc.in_features
    # Keep every child except the last one (the classification layer).
    self.backbone = torch.nn.Sequential(*(list(pretrained.children())[:-1]))

    self.combine = nn.Sequential(
       nn.Linear(num_ftrs * 2, 256),
       nn.Tanh(),
       nn.Linear(256, 128),
       nn.Tanh(),
       nn.Linear(128, 2)
    )

    self.save_hyperparameters()

    # One metric object per stage so their running state never mixes.
    self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
    self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
    self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

  def forward_once(self, x):
    """Run one branch: backbone features flattened to ``(batch, -1)``."""
    output = self.backbone(x)
    output = output.view(output.size()[0], -1)
    return output

  def forward(self, xa, xb):
    """Return per-pair log-probabilities over the two classes.

    Args:
      xa: batch of first images.
      xb: batch of second images, aligned with ``xa``.

    Returns:
      Tensor with one row per pair and two columns holding ``log_softmax``
      outputs (log-probabilities, not raw logits).
    """
    # learner: two branches sharing the same weights
    output1 = self.forward_once(xa)
    output2 = self.forward_once(xb)

    # concatenate the output of the two branches
    # TODO: concat vs. subtract vs. multiplication vs. absolute difference
    combined_output = torch.cat([output1, output2], dim=1)

    final_output = self.combine(combined_output)

    # task: log-probabilities, paired with F.nll_loss in `loss`
    return F.log_softmax(final_output, dim=1)

  def loss(self, xa, xb, ys):
    """Return ``(log_probs, nll_loss)`` for a batch of pairs with labels ``ys``."""
    logits = self(xa, xb)
    loss = F.nll_loss(logits, ys)
    return logits, loss

  def _shared_step(self, batch, metric):
    """Run the forward pass and loss for one batch and update ``metric`` with the argmax predictions."""
    xa, xb, ys = batch
    logits, loss = self.loss(xa, xb, ys)
    preds = torch.argmax(logits, 1)
    metric(preds, ys)
    return logits, loss

  def training_step(self, batch, batch_idx):
    """Log training loss and accuracy (per step and per epoch) and return the loss."""
    _, loss = self._shared_step(batch, self.train_acc)

    # logging metrics we calculated by hand
    self.log('train/loss', loss, on_epoch=True)
    # logging a torchmetrics object lets Lightning handle the epoch aggregation
    self.log('train/acc', self.train_acc, on_epoch=True)

    return loss

  def configure_optimizers(self):
    """Adam over the module's parameters with the learning rate from hparams."""
    return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

  def test_step(self, batch, batch_idx):
    """Log epoch-level test loss and accuracy."""
    _, loss = self._shared_step(batch, self.test_acc)

    self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
    self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

  def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
    """Export the final model to ONNX and upload it as a W&B artifact."""
    dummy_input = torch.zeros(1, 3, 224, 224, device=self.device)
    model_filename = "model_final.onnx"
    # A tuple is the documented way to pass several positional inputs to the
    # exporter; this also matches validation_epoch_end.
    self.to_onnx(model_filename, (dummy_input, dummy_input), export_params=True)
    artifact = wandb.Artifact(name="model.ckpt", type="model")
    artifact.add_file(model_filename)
    wandb.log_artifact(artifact)

  def validation_step(self, batch, batch_idx):
    """Log validation loss and accuracy; return the log-probabilities for the epoch-end histogram."""
    logits, loss = self._shared_step(batch, self.valid_acc)

    self.log("valid/loss_epoch", loss)  # default on val/test is on_epoch only
    self.log('valid/acc_epoch', self.valid_acc)

    return logits

  def validation_epoch_end(self, validation_step_outputs):
    """Export an ONNX snapshot named after the global step and log an output histogram to W&B."""
    dummy_input = torch.zeros(1, 3, 224, 224, device=self.device)
    model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
    torch.onnx.export(self, (dummy_input, dummy_input), model_filename, opset_version=11)
    artifact = wandb.Artifact(name="model.ckpt", type="model")
    artifact.add_file(model_filename)
    self.logger.experiment.log_artifact(artifact)

    # validation_step returns one tensor per batch; flatten them into one vector.
    flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
    self.logger.experiment.log(
        {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
        "global_step": self.global_step})

class ImageCache:
    """Least-recently-used cache of transformed images, keyed by file path.

    Args:
        transform: callable applied to the RGB-converted PIL image; its result
            is what gets cached and returned.
        maxsize: maximum number of entries kept. When an insert pushes the
            cache past this size, the least recently used entry is dropped.
    """

    def __init__(self, transform, maxsize=13000):
        # OrderedDict keeps insertion order; the front is the eviction candidate.
        self.cache = collections.OrderedDict()
        self.transform = transform
        self.maxsize = maxsize

    def __getitem__(self, path):
        """Return the transformed image for ``path``, loading it from disk on a miss."""
        if path in self.cache:
            # Hit: mark the entry as most recently used.
            self.cache.move_to_end(path)
            return self.cache[path]

        # Miss: load and transform. The context manager closes the underlying
        # file as soon as the pixel data has been converted, instead of leaving
        # the handle to the garbage collector.
        with Image.open(path) as raw:
            image = self.transform(raw.convert('RGB'))
        self.cache[path] = image
        if len(self.cache) > self.maxsize:
            # Over capacity: evict the least recently used entry.
            self.cache.popitem(last=False)
        return image
# Module-level cache that TriageDataset.__getitem__ reads from for every split.
# The preprocessing here is fixed at 224x224 with the mean/std constants below.
# NOTE(review): the constants look like the usual ImageNet statistics, and this
# transform is independent of TriageDataModule's `input_size` -- confirm both.
train_cache = ImageCache(transform=transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]))

class TriageDataset():
  """Map-style dataset of image pairs read from a pair-list text file.

  The pair list is ``<data_dir>/train_val/train_pairlist.txt`` when ``train``
  is true and ``<data_dir>/train_val/val_pairlist.txt`` otherwise. Images are
  looked up under ``<data_dir>/train_val/train_val_imgs``.

  Args:
    data_dir: dataset root directory.
    train: selects the training (True) or validation (False) pair list.
    transform: stored on the instance. NOTE(review): ``__getitem__`` does not
      apply it -- images come from the module-level ``train_cache``, which
      carries its own transform.

  Raises:
    FileNotFoundError: if the pair list or the image directory is missing.
  """

  def __init__(self, data_dir='./', train=True, transform=None):
    self.transform = transform
    self.data_dir = data_dir
    self.train = train

    pairs = list(self.parse_pairlist())
    if pairs:
      # Transpose [(a, b, label), ...] into three parallel tuples.
      self.path_a, self.path_b, self.label = zip(*pairs)
    else:
      # zip(*[]) yields nothing to unpack; an empty pair list is an empty dataset.
      self.path_a, self.path_b, self.label = (), (), ()

  def parse_pairlist(self):
    """Yield ``(path_a, path_b, label)`` for each non-blank pair-list line.

    Each line holds six whitespace-separated fields: series id, photo A index,
    photo B index, preference, rank A, rank B. The label is 0 when the
    preference is below 0.5 and 1 otherwise; the rank fields are not used.
    """
    if self.train:
      pairlist = os.path.join(self.data_dir, 'train_val', 'train_pairlist.txt')
    else:
      pairlist = os.path.join(self.data_dir, 'train_val', 'val_pairlist.txt')
    if not os.path.exists(pairlist):
      raise FileNotFoundError(f'Pairlist file {pairlist} does not exist')

    images_basepath = os.path.join(self.data_dir, 'train_val', 'train_val_imgs')
    if not os.path.exists(images_basepath):
      raise FileNotFoundError(f'Images base path {images_basepath} does not exist')

    with open(pairlist, 'r') as f:
      lines = f.readlines()

    num_lines = len(lines)
    for line in tqdm(lines, total=num_lines):
      line = line.strip()
      if not line:
        continue

      series_id, photoA_idx, photoB_idx, preference, rankA, rankB = line.split()
      # File names are "<6-digit series>-<2-digit photo index>.JPG".
      pathA = os.path.join(images_basepath, f'{int(series_id):06d}-{int(photoA_idx):02d}.JPG')
      pathB = os.path.join(images_basepath, f'{int(series_id):06d}-{int(photoB_idx):02d}.JPG')
      result = 0 if float(preference) < 0.5 else 1

      yield pathA, pathB, result

  def __len__(self):
    """Number of pairs."""
    return len(self.label)

  def __getitem__(self, idx):
    """Return ``(image_a, image_b, label)`` for pair ``idx``, with images served by ``train_cache``."""
    image_a = train_cache[self.path_a[idx]]
    image_b = train_cache[self.path_b[idx]]
    label = self.label[idx]

    return image_a, image_b, label
  
  

class TriageDataModule(pl.LightningDataModule):
  """LightningDataModule that wraps TriageDataset for fit and test.

  Args:
    data_dir: dataset root passed to TriageDataset.
    batch_size: batch size used by all three dataloaders.
    input_size: side length used by the Resize/CenterCrop transform built here.
      NOTE(review): TriageDataset currently reads images through the
      module-level ``train_cache`` and does not apply this transform, so
      ``input_size`` has no effect on the tensors produced -- confirm intent.
  """

  def __init__(self, data_dir='./', batch_size=512, input_size=224):
    super().__init__()
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    self.save_hyperparameters()

  def setup(self, stage=None):
    """Build the datasets for ``stage`` ('fit', 'test', or None for both).

    The test split reads the validation pair list (``train=False``).
    """
    if stage == 'fit' or stage is None:
      self.triage_train = TriageDataset(data_dir=self.data_dir, train=True, transform=self.transform)
      self.triage_val = TriageDataset(data_dir=self.data_dir, train=False, transform=self.transform)
    if stage == 'test' or stage is None:
      self.triage_test = TriageDataset(data_dir=self.data_dir, train=False, transform=self.transform)

  def train_dataloader(self):
    """Training loader, reshuffled every epoch."""
    # Without shuffle=True every epoch replays the pair list in file order,
    # so each batch would always contain the same neighbouring pairs.
    return DataLoader(self.triage_train, batch_size=self.batch_size, shuffle=True)

  def val_dataloader(self):
    """Validation loader, in file order."""
    return DataLoader(self.triage_val, batch_size=self.batch_size)

  def test_dataloader(self):
    """Test loader, in file order."""
    return DataLoader(self.triage_test, batch_size=self.batch_size)

# Build the data module and parse the pair lists up front so a bad path fails
# here rather than inside trainer.fit.
triage = TriageDataModule(data_dir='/content/triage')
triage.prepare_data()
triage.setup()

# Pull one validation batch as a smoke test of the loading pipeline.
samples = next(iter(triage.val_dataloader()))

wandb_logger = WandbLogger(project='triage-wandb')

trainer = pl.Trainer(
    logger=wandb_logger,
    log_every_n_steps=5,
    accelerator="auto",
    max_epochs=50,
    deterministic=True,
    # EarlyStopping has to monitor a key the model actually logs. SPSSiamese
    # logs its validation loss as "valid/loss_epoch"; it never logs "val_loss".
    callbacks=[EarlyStopping(monitor="valid/loss_epoch", mode="min")]
)

model = SPSSiamese(backbone='resnet50')

trainer.fit(model, triage)

trainer.test(datamodule=triage,
             ckpt_path=None)

wandb.finish()

INFO:lightning_lite.utilities.seed:Global seed set to 42

100%|██████████| 12075/12075 [00:00<00:00, 146990.24it/s]

100%|██████████| 483/483 [00:00<00:00, 50908.40it/s]

100%|██████████| 483/483 [00:00<00:00, 85772.00it/s]
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs

  0%|          | 0/12075 [00:00<?, ?it/s][A
100%|██████████| 12075/12075 [00:00<00:00, 88158.42it/s]

100%|██████████| 483/483 [00:00<00:00, 74419.54it/s]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params
-------------------------------------------------
0 | backbone  | Sequential         | 23.5 M
1 | combine   | Sequential         | 1.1 M 
2 | train_acc | MulticlassAccuracy | 0     
3 | valid_acc | Multiclass

Sanity Checking: 0it [00:00, ?it/s]

RuntimeError: ignored