In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
import math
from tqdm.notebook import tqdm
import os

import raster_relight as rr
import raster_dataloader as rd

  warn(


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Pick the best available accelerator, falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [3]:
# Model
class LightMPL(torchvision.ops.MLP):
    """MLP that maps flattened features to a unit-length 3-vector.

    The network input width is ``num_feats * 3`` and a final 3-wide layer is
    appended after ``hidden_channels``. The remaining arguments are passed
    straight through to ``torchvision.ops.MLP``.

    Args:
        num_feats: Number of 3-component features per sample.
        hidden_channels: Sequence of hidden layer widths (list or tuple).
        norm_layer, activation_layer, inplace, bias, dropout: Forwarded
            unchanged to ``torchvision.ops.MLP``.
    """

    def __init__(self, num_feats, hidden_channels, norm_layer=None, activation_layer=torch.nn.modules.activation.ReLU, inplace=None, bias=True, dropout=0.0):
        # list(...) lets callers pass a tuple (or any sequence) of widths.
        super().__init__(num_feats * 3, list(hidden_channels) + [3], norm_layer, activation_layer, inplace, bias, dropout)

    def forward(self, x):
        """Return the L2-normalised 3-vector predicted for each input row."""
        x = super().forward(x)
        assert x.shape[-1] == 3
        # Normalise along the vector axis explicitly; the default (dim=1) is
        # only the vector axis when the input is exactly 2-D.
        return F.normalize(x, dim=-1)


In [4]:
# Get the dataloader
def get_dataloader(batch_size, split='train'):
    """Get a dataloader for training or testing.

    Args:
        batch_size: Number of samples per batch.
        split: Name of the split handed to ``rd.RasterDataset``. Batches are
            shuffled only when this is ``'train'``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    is_train = split == 'train'
    full_dataset = rd.RasterDataset(split)  # uses a config to decide what to load.
    return torch.utils.data.DataLoader(dataset=full_dataset,
                                       batch_size=batch_size,
                                       shuffle=is_train,
                                       pin_memory=True, num_workers=0)


In [5]:
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Compute the sample-weighted mean loss of the model on the validation set.

    Puts ``model`` into eval mode (the caller is responsible for switching
    back to train mode) and runs it over every batch of ``valid_dl`` without
    tracking gradients. Batches are moved to the notebook-level ``device``.

    Args:
        model: Module called on the flattened features of each batch.
        valid_dl: Dataloader yielding ``(feats, target_vector)`` pairs.
        loss_func: Callable ``loss_func(outputs, target_vector)`` returning a
            scalar mean loss for the batch.
        log_images: Unused; kept for interface compatibility.
        batch_idx: Unused; kept for interface compatibility.

    Returns:
        The mean loss per sample as a Python float.
    """
    model.eval()
    total_loss = 0.0
    with torch.inference_mode():
        for feats, target_vector in tqdm(valid_dl,
                                         desc="Validating model",
                                         total=len(valid_dl)):
            feats, target_vector = feats.to(device), target_vector.to(device)

            # Forward pass
            inputs = torch.flatten(feats, start_dim=1)
            outputs = model(inputs)  # should be (batch_len, 3)

            # Weight each batch mean by its size so a short final batch does
            # not skew the average; float() detaches the value from the device.
            batch_loss = float(loss_func(outputs, target_vector))
            total_loss += batch_loss * target_vector.size(0)

    return total_loss / len(valid_dl.dataset)


In [6]:
def generate_validation_image(model, valid_dataset):
    """Generate an image comparing a ground truth image with one generated using the model.

    A frame index and a light name are drawn at random from ``valid_dataset``.
    The model predicts one light direction vector per occupied pixel, and three
    panels are built on a background of ones:

    1. the predicted vectors remapped with ``0.5 * v + 0.5``,
    2. the pixels produced by ``rr.raster_from_directions`` from those vectors,
    3. the ground-truth raster pixels.

    Args:
        model: MLP which outputs light direction vectors.
        valid_dataset: Dataset exposing ``num_frames``, ``_lights_info``,
            ``_num_lights`` and ``attributes[frame][light]``.

    Returns:
        The three panels concatenated along axis 1 as a numpy array.
    """
    # Randomly choose which image from the validation set to reconstruct
    image_number = np.random.randint(valid_dataset.num_frames)
    # randomly choose a light in the scene
    light_names = list(valid_dataset._lights_info.keys())
    light_name = light_names[np.random.randint(valid_dataset._num_lights)]
    print(f"Generating OLAT validation image {image_number} with light {light_name}...")

    # load attributes of this validation image
    W, H, raster_image_pixels, world_normals, albedo, occupancy_mask = valid_dataset.attributes[image_number][light_name]
    raster_image_pixels = raster_image_pixels.astype(np.float32)
    world_normals = world_normals.astype(np.float32)
    albedo = albedo.astype(np.float32)

    # prepare inputs for inference
    feats = np.stack([world_normals, albedo, raster_image_pixels], axis=1)
    inputs = torch.flatten(torch.as_tensor(feats).float(), start_dim=1)

    # Run on whatever device the model lives on, so this works for CPU and
    # accelerator-resident models alike.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    with torch.inference_mode():
        # Do inference to get light vectors
        light_vectors = model(inputs)
        # .cpu() is required before .numpy() when the model is not on the CPU.
        light_vectors = light_vectors.cpu().numpy().astype(np.float32)
    print(f"inf vectors were: {light_vectors}")

    # NOTE(review): numpy images are conventionally (H, W, 3); this uses
    # (W, H, 3), which only matters for non-square frames — confirm against
    # the shape of occupancy_mask.
    img_size = (W, H, 3)

    # Construct a normals image
    val_norm_image = np.ones(img_size)
    val_norm_image[occupancy_mask] = 0.5 * light_vectors + 0.5

    # Construct validation image
    val_image = np.ones(img_size)
    val_image[occupancy_mask] = rr.raster_from_directions(light_vectors, albedo, world_normals)

    # Construct GT image
    gt_image = np.ones(img_size)
    gt_image[occupancy_mask] = raster_image_pixels

    return np.concatenate([val_norm_image, val_image, gt_image], axis=1)


In [7]:
def train_epoch(epoch, train_dl, model, loss_func):
    """Train ``model`` for one pass over ``train_dl`` and return the mean batch loss.

    Relies on the notebook-level globals ``optimizer`` (used for the update
    step) and ``device`` (where each batch is moved). Per-step metrics are
    logged to wandb for every step, including the last one of the epoch.

    Args:
        epoch: Zero-based epoch index, used for the fractional epoch metric.
        train_dl: Dataloader yielding ``(feats, target_vector)`` pairs.
        model: Module called on the flattened features of each batch.
        loss_func: Callable ``loss_func(outputs, target_vector)`` returning a
            scalar tensor.

    Returns:
        The mean of the per-batch losses as a Python float.
    """
    # Taken from the loader itself so the function does not depend on a
    # step-count global defined in a later cell.
    n_steps = len(train_dl)
    cumu_loss = 0.0
    for step, (feats, target_vector) in tqdm(enumerate(train_dl),
                                             total=n_steps,
                                             desc="Step",
                                             position=1, leave=False, colour='red'):
        # Move to device
        feats, target_vector = feats.to(device), target_vector.to(device)

        # Forward pass
        inputs = torch.flatten(feats, start_dim=1)  # (batch_len, flattened feats)
        outputs = model(inputs)  # should be (batch_len, 3)
        train_loss = loss_func(outputs, target_vector)
        loss_value = train_loss.item()
        cumu_loss += loss_value

        # Optimization step
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # 🐝 Log train metrics to wandb. A plain float is logged rather than
        # the graph-attached tensor, and the final step is no longer skipped.
        wandb.log({"train/train_loss": loss_value,
                   "train/epoch": (step + 1 + (n_steps * epoch)) / n_steps,
                   })

    return cumu_loss / n_steps

In [15]:
# Make the loss
cosine = nn.CosineSimilarity(dim=1, eps=1e-6)


def loss_func(x, y, lamb=2.0):
    """Cosine-alignment loss plus a penalty on non-unit predictions.

    Args:
        x: Predicted vectors; cosine similarity is taken along dim 1.
        y: Target vectors with the same shape as ``x``.
        lamb: Weight of the unit-norm penalty.

    Returns:
        Scalar tensor ``mse(cos(x, y), 1) + lamb * mse(||x||, 1)``.
    """
    # ones_like keeps the targets on the same device and dtype as the inputs,
    # so no global device is needed and nothing is copied from the CPU.
    cosine_similarity = cosine(x, y)
    similarity_term = F.mse_loss(cosine_similarity, torch.ones_like(cosine_similarity))

    # NOTE(review): LightMPL.forward already L2-normalises its output, so for
    # that model this term is expected to be ~0.
    x_norms = torch.linalg.norm(x, dim=-1)
    unitarity_term = F.mse_loss(x_norms, torch.ones_like(x_norms))
    return similarity_term + lamb * unitarity_term


In [9]:
# Build the training and validation loaders.
batch_size = 1024

train_dl = get_dataloader(batch_size, split='train')
print("Loaded train dataset.")

# Validation needs no gradients, so it runs with twice the batch size.
valid_dl = get_dataloader(2 * batch_size, 'val')
print("Loaded validation dataset.")

Loading dataset from chair_intrinsic/train: 100%|█| 140/140 [00:07<00:00, 19.77it


Loaded train dataset.


Loading dataset from chair_intrinsic/val: 100%|█| 140/140 [00:07<00:00, 19.70it/s


Loaded validation dataset.


In [16]:
# 🐝 initialise a wandb run
# NOTE: The model checkpoint path should be to scratch on the cluster

raster_config = rr.parse_config()

current_run = wandb.init(
    project="light-mlp-supervised-cosine",
    config={
        "epochs": 5,
        "batch_size": batch_size,
        "lr": 1e-3,
        "dropout": 0.0, # random.uniform(0.01, 0.80),
        "num_feats": 3,
        "hidden_channels" : [256]*4 + [128], 
        "model_checkpoint_path" : 'model_checkpoints',
        'model_trained_path' : 'model_trained',
        'scene' : raster_config['paths']['scene']
        })

# Copy your config 
config = wandb.config

# make output dirs; exist_ok avoids the check-then-create race and makes the
# cell safe to re-run.
for output_dir in (config.model_checkpoint_path, config.model_trained_path):
    os.makedirs(output_dir, exist_ok=True)

# Number of optimisation steps in one pass over the training set.
n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112789055510398, max=1.0…

In [17]:
# MLP model, moved to the selected device so it matches the batches that the
# training and validation loops send there.
model = LightMPL(config.num_feats, config.hidden_channels, dropout=config.dropout).to(device)

In [18]:
# Adam over all model parameters, with the learning rate from the run config.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config.lr,
)

In [19]:
# Training: for each epoch run one training pass, validate, log a comparison
# image, and save a checkpoint (or the final weights on the last epoch).
for epoch in tqdm(range(config.epochs),
                  desc="Epoch", 
                  total=config.epochs, 
                  position=0, leave=True, colour='green'):
    # validate_model leaves the model in eval mode, so reset it every epoch.
    model.train()
    
    # Train an epoch
    print(f"Training epoch {epoch}")
    avg_train_loss = train_epoch(epoch, train_dl, model, loss_func)
    wandb.log({"train/avg_loss": avg_train_loss})

    # Validation; float() gives a plain number whether a float or a 0-dim
    # tensor comes back.
    print(f"Validating model after epoch {epoch}")
    val_loss = float(validate_model(model, valid_dl, loss_func))
    print("Done validating.")
    
    # Render a validation image. The array holds three panels side by side:
    # encoded light vectors, the re-rendered image, and the ground truth.
    print("Creating validation image.")
    val_image_array = generate_validation_image(model, valid_dl.dataset)
    val_image = wandb.Image(val_image_array, caption='Left: Inferred Light, Middle: Relit, Right: GT')

    # 🐝 Log train and validation metrics to wandb
    val_metrics = {"val/val_loss": val_loss,
                   "val/images": val_image}
    wandb.log(val_metrics)

    # Both losses use three decimals (the original ":3f" was a width, not a precision).
    print(f"Train Loss: {avg_train_loss:.3f} Valid Loss: {val_loss:.3f}")

    # save the model and upload it to wandb
    if epoch + 1 == config.epochs:
        # save trained model
        trained_file = f"{config.model_trained_path}/{current_run.project}_{current_run.name}.pth"
        torch.save(model.state_dict(),  trained_file)
        wandb.save(trained_file, policy='now')
    else:
        # save checkpoint (overwritten every epoch; 'live' re-uploads on change)
        check_point_file = f"{config.model_checkpoint_path}/{current_run.project}_{current_run.name}_ckpt.pth"
        torch.save(model.state_dict(), check_point_file)
        wandb.save(check_point_file, policy='live')


# 🐝 Close your wandb run 
wandb.finish()

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Training epoch 0


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 0


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]

Done validating.
Creating validation image.
Generating OLAT validation image 89 with light arealightup...
inf vectors were: [[-0.34993544 -0.02144917  0.9365282 ]
 [-0.35065106 -0.0092955   0.9364601 ]
 [-0.35131702  0.07042827  0.9336039 ]
 ...
 [-0.35270765 -0.01188879  0.93565804]
 [-0.34588248  0.29202062  0.89167786]
 [-0.33980444  0.33004725  0.8806826 ]]
Train Loss: 0.016 Valid Loss: 0.013527
Training epoch 1


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 1


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]

Done validating.
Creating validation image.
Generating OLAT validation image 54 with light arealightup...
inf vectors were: [[-0.34974736  0.22969356  0.90824974]
 [-0.34942618  0.23124664  0.90797925]
 [-0.34900504  0.2402474   0.90580165]
 ...
 [-0.35244212  0.2161377   0.9105323 ]
 [-0.3432146   0.35053292  0.87139565]
 [-0.34348428  0.32033092  0.8828401 ]]
Train Loss: 0.013 Valid Loss: 0.012518
Training epoch 2


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 2


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]

Done validating.
Creating validation image.
Generating OLAT validation image 56 with light arealightup...
inf vectors were: [[-0.31687444 -0.36200446  0.876666  ]
 [-0.32871112 -0.2922348   0.8980801 ]
 [-0.34276137 -0.15007924  0.92735696]
 ...
 [-0.3536987   0.1276776   0.9266044 ]
 [-0.3495583   0.24820447  0.9034399 ]
 [-0.3477556   0.32176402  0.880644  ]]
Train Loss: 0.013 Valid Loss: 0.012350
Training epoch 3


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 3


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]

Done validating.
Creating validation image.
Generating OLAT validation image 37 with light arealightup...
inf vectors were: [[-0.29519707 -0.48159972  0.825179  ]
 [-0.3333502  -0.25963497  0.90634835]
 [-0.333368   -0.2595531   0.9063652 ]
 ...
 [-0.33084112 -0.2756182   0.9025402 ]
 [-0.33129346 -0.27182192  0.90352505]
 [-0.32665643 -0.3049314   0.89460176]]
Train Loss: 0.012 Valid Loss: 0.012238
Training epoch 4


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 4


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]

Done validating.
Creating validation image.
Generating OLAT validation image 45 with light arealightfront...
inf vectors were: [[-0.3247084  -0.31345576  0.89236194]
 [-0.27555445 -0.5539451   0.78563   ]
 [-0.28484517 -0.5171006   0.8071371 ]
 ...
 [-0.30988857 -0.39521453  0.86473966]
 [-0.33409595 -0.24141121  0.91109854]
 [-0.33602688 -0.22501172  0.91457945]]
Train Loss: 0.012 Valid Loss: 0.012173


0,1
train/avg_loss,█▂▂▁▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▇▃▆▄▃▂▂▄▃▃▂▃▂▂▃▃▂▂▂▂▃▃▃▂▂▁▂▃▂▁▁▁▁▂▂▂▁▃▂
val/val_loss,█▃▂▁▁

0,1
train/avg_loss,0.01216
train/epoch,4.99989
train/train_loss,0.01209
val/val_loss,0.01217
