In [None]:
# Confirm that a GPU runtime is attached before training.
!nvidia-smi

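# Colab Setup
Clone the repository, download the pretrained DPT-Hybrid (MiDaS) weights, copy the `nutrition5k_lite` dataset archive from the shared Google Drive, and install the dependencies. Google Drive must already be mounted at `/content/drive`.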

In [None]:
!git clone https://github.com/pskchai/food-depth-dpt.git

In [None]:
!wget https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt -P /content/food-depth-dpt/weights/

In [None]:
!mkdir /content/food-depth-dpt/data/
!cp '/content/drive/Shareddrives/Food Analytic/Data/nutrition5k_lite.zip' /content/food-depth-dpt/data/
%cd /content/food-depth-dpt/data/
!unzip -qq nutrition5k_lite.zip
!rm nutrition5k_lite.zip
%cd -

In [None]:
# Install the project's pinned requirements, plus torchinfo (used for the model summary below).
# NOTE(review): torchinfo is not version-pinned; consider pinning it for reproducibility.
%pip install -qqq -r /content/food-depth-dpt/requirements.txt
%pip install -qqq torchinfo

# Preparation
Import libraries and set configurations

In [None]:
# Work from the repository root; the local `finetune` package imported below presumably
# resolves relative to this directory.
%cd /content/food-depth-dpt/

# NOTE(review): `time` and `plt` are not used in the visible cells of this notebook.
import os
import time
from datetime import datetime
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from torchinfo import summary
from finetune.models import DPTModule

In [None]:
# Every tunable setting for the run lives here; later cells only read from `config`.
config = {
    # Forwarded to DPTModule as `scale` / `shift` -- presumably the affine mapping applied to
    # DPT's raw output to obtain depth. TODO confirm in finetune/models.py.
    'base_scale': 0.0000305,
    'base_shift': 0.1378,
    'batch_size': 16,
    'image_size': (384, 384),
    # Learning-rate bounds forwarded to DPTModule (presumably for a cyclic / one-cycle
    # schedule -- verify in finetune/models.py).
    'base_lr': 1e-6,
    'max_lr': 1e-5,
    'num_epochs': 70,
    'early_stopping_patience': 10,
    'num_workers': 2,
    'seed': 42,
    'model_path': '/content/food-depth-dpt/weights/dpt_hybrid-midas-501f0c75.pt',
    'dataset_path': '/content/food-depth-dpt/data/nutrition5k/',
    'weights_save_path': '/content/drive/Shareddrives/Food Analytic/models/DPT/',
    'logs_save_path': '/content/drive/Shareddrives/Food Analytic/models/DPT/',
    # Lightning checkpoint that `trainer.fit` resumes from. This points at a specific earlier
    # run (version_1, epoch 57); set it to None to start a new run instead.
    'checkpoint_path': '/content/drive/Shareddrives/Food Analytic/models/DPT/lightning_logs/version_1/checkpoints/epoch=57-step=9976.ckpt',
}

pl.seed_everything(config['seed'])

# TensorBoard

In [None]:
# Register the %tensorboard magic used in the next cell.
%load_ext tensorboard

In [None]:
# Show TensorBoard inline for the directory the trainer's TensorBoardLogger writes to.
%tensorboard --logdir "{config['logs_save_path']}"

# Training
Create the DPT model and fine-tune it with PyTorch Lightning. When `config['checkpoint_path']` is set, training resumes from that checkpoint.

In [None]:
# Assemble the Lightning module's constructor arguments from the config, then build it.
dpt_hparams = dict(
    model_path=config['model_path'],
    dataset_path=config['dataset_path'],
    scale=config['base_scale'],
    shift=config['base_shift'],
    batch_size=config['batch_size'],
    base_lr=config['base_lr'],
    max_lr=config['max_lr'],
    num_workers=config['num_workers'],
    image_size=config['image_size'],
)
model = DPTModule(**dpt_hparams)

# Layer-by-layer summary of the wrapped network for a single 3-channel input at the configured size.
input_h, input_w = config['image_size'][0], config['image_size'][1]
summary(model.model, input_size=(1, 3, input_h, input_w))

In [None]:
# Runs are written to <logs_save_path>/lightning_logs/version_N (the layout the export cell reads).
logger = pl.loggers.TensorBoardLogger(
    save_dir=config['logs_save_path'],
)

lr_monitor = pl.callbacks.LearningRateMonitor()
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss", patience=config['early_stopping_patience'])
# Keep the checkpoint with the lowest validation loss. Lightning's default checkpoint only
# tracks the most recent epoch. After early stopping, the saved (and later exported) weights
# would then be up to `early_stopping_patience` epochs past the best model.
best_checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min")

trainer = pl.Trainer(
    devices='auto',
    accelerator='auto',
    max_epochs=config['num_epochs'],
    logger=logger,
    callbacks=[lr_monitor, early_stopping, best_checkpoint],
    # NOTE(review): `weights_save_path` is deprecated in newer Lightning releases. In the current
    # config it equals logs_save_path, so checkpoints land under the logger's version directory.
    weights_save_path=config['weights_save_path']
)

# Resumes from config['checkpoint_path']; set that entry to None to start a new run.
trainer.fit(model, ckpt_path=config['checkpoint_path'])

After training finishes successfully, run the following cell. It extracts the DPT weights from the latest PyTorch Lightning checkpoint and saves them as a plain `state_dict` (`.pt`) file.

In [None]:
# Find the newest training run and its checkpoint, reload the Lightning module from it, and save
# the inner DPT network's state_dict as a standalone, timestamped .pt file.
module_name = 'lightning_logs'
runs_dir = os.path.join(config['weights_save_path'], module_name)

# Choose the newest `version_N` directory by its numeric suffix. A plain lexical sort ranks
# `version_9` above `version_10`.
version_dirs = [
    name for name in os.listdir(runs_dir)
    if name.startswith('version_') and name[len('version_'):].isdigit()
]
if not version_dirs:
    raise FileNotFoundError(f'No version_N run directories found in {runs_dir}')
latest_version = max(version_dirs, key=lambda name: int(name[len('version_'):]))

checkpoint_base_path = os.path.join(runs_dir, latest_version, 'checkpoints')
saved_model_base_path = os.path.join(config['weights_save_path'], 'state_dict')

# Choose the most recently written checkpoint. Names like `epoch=9-...` vs `epoch=57-...` do not
# sort numerically either.
checkpoint_files = [f for f in os.listdir(checkpoint_base_path) if f.endswith('.ckpt')]
if not checkpoint_files:
    raise FileNotFoundError(f'No .ckpt files found in {checkpoint_base_path}')
checkpoint_filename = max(
    checkpoint_files,
    key=lambda f: os.path.getmtime(os.path.join(checkpoint_base_path, f)),
)
saved_model_filename = f'{datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}.pt'

# `load_from_checkpoint` is a classmethod; calling it on the class means this cell does not depend
# on `model` from the training cell. The keyword arguments mirror the training cell's constructor.
loaded_module = DPTModule.load_from_checkpoint(
    os.path.join(checkpoint_base_path, checkpoint_filename),
    model_path=config['model_path'],
    dataset_path=config['dataset_path'],
    scale=config['base_scale'],
    shift=config['base_shift'],
    batch_size=config['batch_size'],
    base_lr=config['base_lr'],
    max_lr=config['max_lr'],
    num_workers=config['num_workers'],
    image_size=config['image_size'],
)

# `.model` is the wrapped DPT network (the same attribute summarized in the training cell).
model_state_dict = loaded_module.model.state_dict()

os.makedirs(saved_model_base_path, exist_ok=True)
saved_model_path = os.path.join(saved_model_base_path, saved_model_filename)
torch.save(model_state_dict, saved_model_path)
saved_model_path