In [None]:
import sys
import kornia
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from IPython.display import display
import os
import random
import numpy as np
import multiprocessing
import subprocess
torch.set_float32_matmul_precision('medium')
import torch.distributed as dist
from elasticdino.model.elasticdino import ElasticDino


%reload_ext tensorboard



In [None]:
HYPERSIM_PATHS = []
# NOTE(review): N_HYPERSIM_IMAGES is not referenced anywhere in this notebook.
N_HYPERSIM_IMAGES = 6
HYPERSIM_BASE_PATH = "path/to/hypersim"

# Collect every frame file whose name contains "color" from
# <base>/<scene>/images/<camera dir>/. os.listdir returns entries in an
# arbitrary, filesystem-dependent order, so every level is sorted; otherwise
# the train/test split below could differ between machines or runs.
for scene in sorted(os.listdir(HYPERSIM_BASE_PATH)):
    images_dir = os.path.join(HYPERSIM_BASE_PATH, scene, "images")
    if not os.path.isdir(images_dir):
        # Skip stray files (e.g. .DS_Store) sitting next to the scene folders.
        continue
    for camera_dir in sorted(os.listdir(images_dir)):
        frame_dir = os.path.join(images_dir, camera_dir)
        for frame in sorted(os.listdir(frame_dir)):
            if "color" in frame:
                HYPERSIM_PATHS.append(os.path.join(frame_dir, frame))

TRAIN_PROPORTION = 0.8

# The first TRAIN_PROPORTION of the sorted paths is used for training, the rest
# for testing. Paths are grouped by scene, so the split is (almost) scene-level:
# only the scene straddling the cut contributes frames to both sides.
train_size = int(len(HYPERSIM_PATHS) * TRAIN_PROPORTION)
HYPERSIM_TRAIN_PATHS = HYPERSIM_PATHS[:train_size]
HYPERSIM_TEST_PATHS = HYPERSIM_PATHS[train_size:]


In [3]:

IMAGE_SIZE = 128


def process_image(img):
    """Convert ``img`` to RGB, crop its top-left square and resize it.

    The square's side equals the shorter of the image's width and height and
    is anchored at (0, 0) rather than centred. The returned image is
    ``IMAGE_SIZE`` x ``IMAGE_SIZE``.
    """
    side = min(img.width, img.height)
    square = img.convert("RGB").crop((0, 0, side, side))
    return square.resize((IMAGE_SIZE, IMAGE_SIZE))

    
def hypersim_sample(p):
    """Load one Hypersim frame and its aligned targets as float tensors.

    Args:
        p: Path to a ``frame.<id>.color.jpg`` file inside a Hypersim image
            directory whose name contains "final". The matching normal map is
            read from the sibling directory with "final" replaced by "geometry".

    Returns:
        ``(img, albedo, shading, normal)``. Each image is passed through
        ``process_image``, converted with ``pil_to_tensor`` and divided by 255,
        so values lie in [0, 1]. Normals keep their PNG encoding in [0, 1];
        they are not remapped to [-1, 1].
    """
    folder = os.path.dirname(p)
    # "frame.0000.color.jpg" -> "0000". basename honours the OS path separator,
    # unlike splitting on "/".
    frame_id = os.path.basename(p).split(".")[1]
    # Only rewrite the leaf directory name: replacing "final" across the whole
    # path would also corrupt any parent directory that happens to contain it.
    geometry_folder = os.path.join(
        os.path.dirname(folder), os.path.basename(folder).replace("final", "geometry")
    )

    def load(directory, filename):
        # The with-block closes the source file as soon as the processed copy
        # exists, instead of leaving that to Pillow / garbage collection.
        with Image.open(os.path.join(directory, filename)) as raw:
            processed = process_image(raw)
        return torchvision.transforms.functional.pil_to_tensor(processed) / 255.0

    img = load(folder, f"frame.{frame_id}.color.jpg")
    albedo = load(folder, f"frame.{frame_id}.diffuse_reflectance.jpg")
    shading = load(folder, f"frame.{frame_id}.diffuse_illumination.jpg")
    normal = load(geometry_folder, f"frame.{frame_id}.normal_bump_cam.png")
    return img, albedo, shading, normal

class HypersimDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of Hypersim colour-frame paths.

    Item ``i`` is whatever ``hypersim_sample`` returns for ``paths[i]``.
    """

    def __init__(self, paths):
        # Kept as given; len() and indexing go straight through to it.
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return hypersim_sample(self.paths[idx])



# Train / test datasets over the two path lists split earlier.
hypersim_train_ds, hypersim_test_ds = (
    HypersimDataset(HYPERSIM_TRAIN_PATHS),
    HypersimDataset(HYPERSIM_TEST_PATHS),
)


In [4]:
def get_dataloader(dataset, batch_size, shuffle=True, num_workers=32):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: Map-style dataset to iterate.
        batch_size: Samples per batch.
        shuffle: Reshuffle every epoch (default True, as before).
        num_workers: Loader worker processes (default 32, as before). This
            count applies to each caller, so when every training process builds
            its own loader the total multiplies; lower it on smaller machines.

    Returns:
        A DataLoader with ``pin_memory=False``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=False,
        num_workers=num_workers,
    )

In [None]:
# NOTE(review): BATCH_SIZE is not referenced by train_parallel below, which
# reads its batch size from train_config["batch_size"]; kept for compatibility.
BATCH_SIZE = 16

from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs
from elasticdino.model.layers import ResidualBlock, Activation, ProjectionLayer



class Model(nn.Module):
    """Albedo / shading / normal predictor on top of a frozen ElasticDino backbone.

    ``forward`` runs the backbone without gradients and encodes the raw image
    with a trainable ``image_encoder``. It concatenates both feature maps along
    the channel axis, fuses them with ``neck``, and decodes three 3-channel maps
    with independent heads. Each head ends in a plain 1x1 conv, so outputs are
    unbounded.

    Args:
        elasticdino: Pretrained backbone module. It is frozen and kept in eval mode.
        feature_dim: Number of channels produced by the backbone. Defaults to
            the previously hard-coded 1024 (presumably the width of
            "elasticdino-32-L" — TODO confirm for other variants).
    """

    def __init__(self, elasticdino, feature_dim=1024):
        super().__init__()
        # requires_grad_ is a method and must be *called*. Assigning to it
        # replaced the method with False and left every backbone parameter
        # with requires_grad=True.
        elasticdino.requires_grad_(False)
        self.elasticdino = elasticdino.eval()
        self.image_encoder = nn.Sequential(
            ProjectionLayer(3, 256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
        )
        self.neck = nn.Sequential(
            ProjectionLayer(feature_dim + 256, 256),
            ResidualBlock(256),
            ResidualBlock(256),
            ResidualBlock(256),
        )

        def make_head():
            # Two residual blocks, then a 256 -> 128 -> 64 -> 3 stack of 1x1 convs.
            return nn.Sequential(
                ResidualBlock(256),
                ResidualBlock(256),
                nn.Conv2d(256, 128, 1),
                nn.ReLU(),
                nn.Conv2d(128, 64, 1),
                nn.ReLU(),
                nn.Conv2d(64, 3, 1),
            )

        self.albedo_head = make_head()
        self.shading_head = make_head()
        self.normal_head = make_head()

    def forward(self, x):
        """Return ``(albedo, shading, normal)`` predictions for image batch ``x``.

        Assumes the backbone returns a ``(B, feature_dim, H, W)`` map whose
        spatial size matches ``image_encoder(x)`` — TODO confirm.
        """
        with torch.no_grad():
            features = self.elasticdino(x)
        encoded = self.image_encoder(x)
        fused = self.neck(torch.cat([encoded, features], dim=1))
        return self.albedo_head(fused), self.shading_head(fused), self.normal_head(fused)

    def _trainable_modules(self):
        # Everything except the frozen backbone, in the original parameter order.
        return [self.neck, self.albedo_head, self.shading_head, self.normal_head, self.image_encoder]

    def parameters(self, recurse=True):
        """Return the trainable parameters only (backbone excluded), as a list.

        ``recurse`` mirrors ``nn.Module.parameters``. ``Model`` registers no
        parameters directly on itself, so ``recurse=False`` yields an empty list.
        """
        if not recurse:
            return list(super().parameters(recurse=False))
        return [p for module in self._trainable_modules() for p in module.parameters()]

    def train(self, mode=True):
        """Set train/eval mode on the trainable parts; the backbone always stays in eval.

        Matches ``nn.Module.train``: it accepts ``mode`` and returns ``self``,
        so ``model.eval()`` (which calls ``self.train(False)``) works.
        """
        self.training = mode
        for module in self._trainable_modules():
            module.train(mode)
        self.elasticdino.eval()
        return self
        
def get_optimizers(model, dataloader, lr, accelerator=None, decay_period=None):
    """Build an AdamW optimizer and a LambdaLR scheduler, optionally via Accelerate.

    Args:
        model: Module whose ``parameters()`` are optimized.
        dataloader: Training dataloader; only used by ``accelerator.prepare``.
        lr: Base learning rate.
        accelerator: If given, model, dataloader, optimizer and scheduler are
            passed through ``accelerator.prepare`` and the prepared objects
            are returned.
        decay_period: ``None`` (default) keeps the learning rate constant.
            Otherwise the rate is multiplied by ``10 ** (-step / decay_period)``,
            i.e. it drops tenfold every ``decay_period`` scheduler steps. (When
            prepared by Accelerate, the scheduler may advance more than once per
            optimizer step — TODO confirm against the Accelerate docs.)

    Returns:
        ``(model, dataloader, optimizer, scheduler)``.

    Raises:
        ValueError: If ``decay_period`` is given but not positive.
    """
    if decay_period is not None and decay_period <= 0:
        raise ValueError(f"decay_period must be positive, got {decay_period!r}")

    optimizer = torch.optim.AdamW(
        [{"params": model.parameters(), "lr": lr}], eps=1e-5, weight_decay=0.03
    )

    def lr_lambda(step):
        # LambdaLR passes its internal step counter; the training loop steps
        # the scheduler once per batch, so this counts iterations, not epochs.
        if decay_period is None:
            return 1
        return 10.0 ** (-step / decay_period)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    if accelerator is not None:
        model, dataloader, optimizer, scheduler = accelerator.prepare(model, dataloader, optimizer, scheduler)
    return model, dataloader, optimizer, scheduler

import numpy as np

import torchvision
from PIL import Image

def debug_step(batch, results, running_loss, n, display_size):
    """Print the running loss and display a ground-truth / prediction image grid.

    For each ``(ground truth, prediction)`` pair taken from ``zip(batch,
    results)``, only the first sample of the batch is shown. Predictions are
    clamped to [0, 1] (ground truths are not); both are scaled by 255, cast to
    uint8 and resized to ``display_size`` squares. Ground truths form the top
    row and predictions the bottom row. ``n`` is accepted but unused.
    """
    def to_pil(chw):
        hwc = chw.permute((1, 2, 0)).detach().cpu().numpy() * 255
        return Image.fromarray(hwc.astype(np.uint8)).resize((display_size, display_size))

    with torch.no_grad():
        pairs = list(zip(batch, results))
        top_row = [to_pil(target[0]) for target, _ in pairs]
        bottom_row = [to_pil(prediction[0].clamp(0, 1)) for _, prediction in pairs]

    grid = np.vstack([np.hstack(top_row), np.hstack(bottom_row)]).astype(np.uint8)
    print(running_loss)

    display(Image.fromarray(grid))

def compute_loss(x, y, window_size=11):
    """Reconstruction loss between two image batches: MSE plus SSIM loss.

    Args:
        x, y: Image tensors to compare; presumably ``(B, C, H, W)`` with values
            in [0, 1] (kornia's ``ssim_loss`` defaults to ``max_val=1.0``) —
            TODO confirm for the normal maps.
        window_size: SSIM window size (previously hard-coded to 11).

    Returns:
        Scalar tensor: mean squared error plus ``kornia.losses.ssim_loss``.
    """
    mse = (x - y).square().mean()
    return mse + kornia.losses.ssim_loss(x, y, window_size)
    
def train_parallel(train_config):
    """Train ``Model`` on ``hypersim_train_ds`` with Hugging Face Accelerate (fp16).

    Meant to run in every process started by ``notebook_launcher``. Each
    iteration sums ``compute_loss`` over the albedo, normal and shading maps.

    Keys read from ``train_config`` (defaults in parentheses):
        n_epochs (1), lr (1e-4), max_iterations (None = no limit),
        debug_interval (50), display_size (128), batch_size (8),
        save_interval (1000), save_path (None).
    When ``save_path`` is set, the unwrapped model's ``state_dict`` is written
    there every ``save_interval`` iterations and once more when training stops.
    With the default ``None``, nothing is saved. Other keys (e.g.
    ``decay_period``) are ignored.
    """
    set_seed(42)
    # Let DDP tolerate parameters that receive no gradient in a step.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(mixed_precision="fp16", kwargs_handlers=[ddp_kwargs], dynamo_backend="no")

    n_epochs = train_config.get("n_epochs", 1)
    lr = train_config.get("lr", 1e-4)
    max_iterations = train_config.get("max_iterations", None)
    debug_interval = train_config.get("debug_interval", 50)
    save_interval = train_config.get("save_interval", 1000)
    save_path = train_config.get("save_path", None)
    display_size = train_config.get("display_size", 128)
    batch_size = train_config.get("batch_size", 8)

    ed = ElasticDino.from_pretrained("path/to/edino", "elasticdino-32-L")
    model = Model(ed)
    dataloader = get_dataloader(hypersim_train_ds, batch_size)
    model, dataloader, optimizer, scheduler = get_optimizers(model, dataloader, lr, accelerator)

    def save_checkpoint():
        # No-op unless save_path is configured. accelerator.save writes from the
        # main process only; unwrap_model strips the distributed wrapper.
        if save_path is not None:
            accelerator.save(accelerator.unwrap_model(model).state_dict(), save_path)

    running_loss = None
    n = 0

    accelerator.print("Start training")
    for epoch in range(n_epochs):
        if n == max_iterations:
            break
        accelerator.print("Epoch", epoch)
        for img, albedo, shading, normal in dataloader:
            # break (not return) so the final checkpoint below is still written.
            if n == max_iterations:
                break
            img = img.to(device=accelerator.device)
            albedo = albedo.to(device=accelerator.device)
            shading = shading.to(device=accelerator.device)
            normal = normal.to(device=accelerator.device)
            n += 1
            with accelerator.autocast():
                pred_albedo, pred_shading, pred_normal = model(img)
                loss = (
                    compute_loss(albedo, pred_albedo)
                    + compute_loss(normal, pred_normal)
                    + compute_loss(shading, pred_shading)
                )
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Exponential moving average of the per-process loss, for logging only.
            loss_value = loss.item()
            if running_loss is None:
                running_loss = loss_value
            else:
                running_loss = 0.98 * running_loss + 0.02 * loss_value

            if n % debug_interval == 0 and accelerator.is_local_main_process:
                debug_step([img, albedo, shading, normal], [img, pred_albedo, pred_shading, pred_normal], running_loss, n, display_size)

            if save_interval and n % save_interval == 0:
                save_checkpoint()

            # Drop references before the next batch and forward pass so these
            # tensors can be freed early.
            del img, albedo, shading, normal, loss, pred_albedo, pred_normal, pred_shading

    save_checkpoint()
        
from accelerate import notebook_launcher

# Hyper-parameters passed to train_parallel in every launched process.
train_config = {
    "n_epochs": 8,
    # "max_iterations": 2,  # uncomment for a quick smoke run
    "lr": 1e-4,
    "decay_period": 5000,
    "debug_interval": 300,
    "save_interval": 5,
    "display_size": 128,
    "batch_size": 16,
}

args = [train_config]

# Start two processes, each calling train_parallel(train_config).
notebook_launcher(train_parallel, args, num_processes=2)