In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
import math
from tqdm.notebook import tqdm
import os

import raster_relight as rr
import raster_dataloader as rd

  warn(


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS,
# and finally plain CPU when no accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


In [3]:
# Model
class LightMPL(torchvision.ops.MLP):
    """MLP mapping flattened per-sample features to a 3-component vector.

    The input width is ``num_feats * 3`` and a final 3-unit layer is appended
    to ``hidden_channels``. The forward pass is inherited unchanged from
    ``torchvision.ops.MLP``.

    Args:
        num_feats: number of input features; the input layer is sized
            ``num_feats * 3`` (assumes each feature has 3 components --
            TODO confirm against the dataset).
        hidden_channels: sequence of hidden layer widths (list, tuple, ...).
        norm_layer, activation_layer, inplace, bias, dropout: forwarded
            positionally to ``torchvision.ops.MLP``.
    """

    def __init__(self, num_feats, hidden_channels, norm_layer=None, activation_layer=torch.nn.modules.activation.ReLU, inplace=None, bias=True, dropout=0.0):
        # list(...) accepts any sequence; plain `hidden_channels + [3]` raised
        # TypeError when a tuple was passed.
        layer_widths = list(hidden_channels) + [3]
        super().__init__(num_feats * 3, layer_widths, norm_layer, activation_layer, inplace, bias, dropout)


In [4]:
# Get the dataloader
def get_dataloader(batch_size, split='train', num_workers=0):
    """Get a dataloader for training or testing.

    Args:
        batch_size: number of samples per batch.
        split: dataset split name passed to ``rd.RasterDataset``; batches are
            shuffled only when this is ``'train'``.
        num_workers: number of loader worker processes (default 0, i.e. data
            is loaded in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.
    """
    is_train = split == 'train'
    full_dataset = rd.RasterDataset(split)  # uses a config to decide what to load.
    return torch.utils.data.DataLoader(
        dataset=full_dataset,
        batch_size=batch_size,
        shuffle=is_train,
        # Pinned memory only helps host-to-CUDA copies; requesting it without
        # CUDA just triggers a warning.
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
    )


In [5]:
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate the model on the validation loader and return the mean loss per sample.

    The model is switched to eval mode and left in that mode on return.
    ``log_images`` and ``batch_idx`` are accepted but not used yet (see the
    TODO in the loop body).
    """
    model.eval()
    running_loss = 0.
    with torch.inference_mode():
        progress = tqdm(enumerate(valid_dl),
                        desc="Validating model",
                        total=len(valid_dl))
        for _, (feats, target_vector) in progress:
            feats = feats.to(device)
            target_vector = target_vector.to(device)

            # Flatten everything but the batch axis; output should be (batch_len, 3)
            predictions = model(feats.flatten(start_dim=1))

            # Weight the batch loss by the batch size so that the final division
            # by the dataset size yields a per-sample average.
            batch_len = target_vector.size(0)
            running_loss = running_loss + loss_func(predictions, target_vector) * batch_len

            # TODO: Implement image reconstruction using validation image and model,
            # logging one batch of images (always the same batch_idx) to the dashboard.
    return running_loss / len(valid_dl.dataset)


In [6]:
def generate_validation_image(model, valid_dataset):
    """Generate an image comparing a ground truth image with one generated using the model.

    A random frame and a random light are drawn from ``valid_dataset``; the
    model predicts a light vector per occupied pixel, and the pixels are
    re-rendered with ``rr.raster_from_directions``.

    Args:
        model: MLP which outputs light direction vectors.
        valid_dataset: dataset exposing ``num_frames``, ``_lights_info``,
            ``_num_lights`` and ``attributes[frame][light]``.

    Returns:
        numpy array holding the model-rendered image and the ground-truth
        image concatenated along axis 1.
    """
    # Run inference on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.inference_mode():
        # Randomly choose which image from the validation set to reconstruct
        image_number = np.random.randint(valid_dataset.num_frames)
        # randomly choose a light in the scene
        light_names = list(valid_dataset._lights_info.keys())
        light_name = light_names[np.random.randint(valid_dataset._num_lights)]
        print(f"Generating OLAT validation image {image_number} with light {light_name}...")

        # load attributes of this validation image
        W, H, raster_image_pixels, world_normals, albedo, occupancy_mask = valid_dataset.attributes[image_number][light_name]
        raster_image_pixels = raster_image_pixels.astype(np.float32)
        world_normals = world_normals.astype(np.float32)
        albedo = albedo.astype(np.float32)
        print(f"VALIMG: Albedo had shape {albedo.shape} and type {albedo.dtype}")
        print(f"VALIMG: world_normals had shape {world_normals.shape} and type {world_normals.dtype}")
        print(f"VALIMG: raster_image_pixels had shape {raster_image_pixels.shape} and type {raster_image_pixels.dtype}")
        print(f"VALIMG: occupancy_mask had shape {occupancy_mask.shape} and type {occupancy_mask.dtype}")

        # prepare inputs for inference: stack the per-pixel features, then
        # flatten everything but the pixel axis
        feats = np.stack([world_normals, albedo, raster_image_pixels], axis=1)
        inputs = torch.flatten(torch.as_tensor(feats).float(), start_dim=1).to(model_device)
        print(f"VALIMG: inputs had shape {inputs.shape} and type {inputs.dtype}")

        # Do inference to get light vectors; bring them back to the host
        # before converting to numpy (required when the model is on GPU/MPS)
        light_vectors = model(inputs).cpu().numpy().astype(np.float32)
        print(f"VALIMG: Inf Light vecs have shape {light_vectors.shape} and type {light_vectors.dtype}")

        # Construct validation image on a white background
        # NOTE(review): canvases are allocated as (W, H, 3); confirm that
        # occupancy_mask uses the same axis order.
        val_image = np.ones((W, H, 3))
        val_raster_pixels = rr.raster_from_directions(light_vectors, albedo, world_normals)
        print(f"VALIMG: val_raster_pixels have shape {val_raster_pixels.shape} and type {val_raster_pixels.dtype}")

        val_image[occupancy_mask] = val_raster_pixels

        # Construct GT image
        gt_image = np.ones((W, H, 3))
        gt_image[occupancy_mask] = raster_image_pixels

    # numpy's keyword is `axis`; the torch-style `dim` raised a TypeError.
    return np.concatenate([val_image, gt_image], axis=1)


In [7]:
def train_epoch(epoch, train_dl, model, loss_func, max_steps=None):
    """Train the model for one epoch and return the average loss per step.

    Relies on the notebook-level globals ``device``, ``optimizer``, ``wandb``
    and ``tqdm``.

    Args:
        epoch: zero-based epoch index, used to compute the fractional epoch logged to wandb.
        train_dl: iterable of ``(feats, target_vector)`` batches that supports ``len()``.
        model: module being optimised; ``optimizer`` must hold its parameters.
        loss_func: callable ``(outputs, target_vector) -> scalar tensor``.
        max_steps: optional cap on the number of batches to process, e.g.
            ``max_steps=1`` for a quick smoke test. ``None`` runs the full epoch.

    Returns:
        Mean training loss over the steps actually run (0.0 if none ran).
    """
    steps_per_epoch = len(train_dl)
    cumu_loss = 0.0
    steps_run = 0
    for step, (feats, target_vector) in tqdm(enumerate(train_dl),
                                             total=steps_per_epoch,
                                             desc=f"Step",
                                             position=1, leave=False, colour='red'):
        # Optional early stop; replaces the unconditional debugging `break`
        # that limited every epoch to a single batch.
        if max_steps is not None and step >= max_steps:
            break

        # Move to device
        feats, target_vector = feats.to(device), target_vector.to(device)

        # Forward pass
        inputs = torch.flatten(feats, start_dim=1)  # (batch_len, flattened features)
        outputs = model(inputs)  # should be (batch_len, 3)
        train_loss = loss_func(outputs, target_vector)
        step_loss = train_loss.item()
        cumu_loss += step_loss
        steps_run += 1

        # Optimization step
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Collect metrics (log the Python float, not the graph-attached tensor)
        metrics = {"train/train_loss": step_loss,
                   "train/epoch": (step + 1 + (steps_per_epoch * epoch)) / steps_per_epoch,
                   }

        # The final step of an epoch is not logged from here (original behaviour kept).
        if step + 1 < steps_per_epoch:
            # 🐝 Log train metrics to wandb
            wandb.log(metrics)

    # Average over the steps that really ran, not the nominal epoch length.
    return cumu_loss / steps_run if steps_run else 0.0

In [8]:
# Launch a single experiment; hyperparameters (including dropout) are fixed in the config below.

# 🐝 initialise a wandb run
# NOTE: The model checkpoint path should be to scratch on the cluster

current_run = wandb.init(
    project="light-mlp-supervised-cosine",
    config={
        "epochs": 2,
        "batch_size": 1024,
        "lr": 1e-3,
        "dropout": 0.0, # random.uniform(0.01, 0.80),
        "num_feats": 3,
        "hidden_channels" : [256]*4 + [128], 
        "model_checkpoint_path" : 'model_checkpoints',
        'model_trained_path' : 'model_trained',
        })

# Copy your config 
config = wandb.config

# make output dirs; exist_ok makes this a no-op when the cell is re-run and
# avoids the check-then-create race of a separate os.path.exists() test
os.makedirs(config.model_checkpoint_path, exist_ok=True)
os.makedirs(config.model_trained_path, exist_ok=True)

# Get the data
train_dl = get_dataloader(batch_size=config.batch_size)
print("Loaded train dataset.")

valid_dl = get_dataloader(batch_size=2*config.batch_size, split='val')
print("Loaded validation dataset.")

n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

# MLP model. It must live on `device`, because the training and validation
# loops move every batch there before calling the model.
model = LightMPL(config.num_feats, config.hidden_channels, dropout=config.dropout).to(device)

# Make the loss and optimizer
cosine = nn.CosineSimilarity(dim=1, eps=1e-6)

def loss_func(x, y):
    """Mean squared deviation of the row-wise cosine similarity from 1.

    Args:
        x: predicted vectors; similarity is taken along dim 1.
        y: target vectors, same shape as ``x``.

    Returns:
        Scalar tensor: 0 when every prediction is parallel to its target,
        4 when every prediction is anti-parallel.
    """
    cosine_similarity = cosine(x, y)
    # ones_like inherits device and dtype from the similarity tensor, so no
    # per-call host-to-device copy and no dependence on the global `device`.
    similarity_target = torch.ones_like(cosine_similarity)
    return F.mse_loss(cosine_similarity, similarity_target)

# Adam over all model parameters, with the learning rate taken from the wandb
# config. train_epoch reads this module-level `optimizer` directly.
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)


[34m[1mwandb[0m: Currently logged in as: [33mdtetruash[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading dataset from chair_intrinsic/train: 100%|█| 140/140 [00:08<00:00, 17.45it/


Loaded train dataset.


Loading dataset from chair_intrinsic/val: 100%|█| 140/140 [00:07<00:00, 17.90it/s]


Loaded validation dataset.


In [None]:
# Training: for every epoch, train, validate, render a comparison image,
# log everything to wandb and save the model weights.
for epoch in tqdm(range(config.epochs),
                  desc="Epoch", 
                  total=config.epochs, 
                  position=0, leave=True, colour='green'):
    # validate_model() switches the model to eval mode, so re-enable training mode
    model.train()
    
    # Train an epoch
    print(f"Training epoch {epoch}")
    avg_loss = train_epoch(epoch, train_dl, model, loss_func)
    wandb.log({"train/avg_loss": avg_loss})

    # Validation
    print(f"Validating model after epoch {epoch}")
    val_loss = validate_model(model, valid_dl, loss_func)
    print("Done validating.")
    
    # Render a validation image
    print("Creating validation image.")
    val_image_array = generate_validation_image(model, valid_dl.dataset)
    # Wrap the rendered array (previously an undefined name was passed here)
    val_image = wandb.Image(val_image_array, caption='Top: Infered Light, Bottom: GT')
    # Write a local copy through the PIL image exposed by wandb.Image;
    # NOTE(review): wandb.Image appears to have no save() of its own -- confirm.
    val_image.image.save(f"validation_image_{epoch:03}.png")

    # 🐝 Log train and validation metrics to wandb
    val_metrics = {"val/val_loss": val_loss,
                   "val/images": val_image}
    wandb.log(val_metrics)

    # `train_loss` only exists inside train_epoch; report the epoch average instead
    print(f"Train Loss: {avg_loss:.3f} Valid Loss: {val_loss:.3f}")

    # save the model and upload it to wandb
    if epoch + 1 == config.epochs:
        # save trained model
        trained_file = f"{config.model_trained_path}/{current_run.project}_{current_run.name}.pth"
        torch.save(model.state_dict(),  trained_file)
        wandb.save(trained_file, policy='now')
    else:
        # save checkpoint (same file name each epoch, so it is overwritten)
        check_point_file = f"{config.model_checkpoint_path}/{current_run.project}_{current_run.name}_ckpt.pth"
        torch.save(model.state_dict(), check_point_file)
        wandb.save(check_point_file, policy='live')


# 🐝 Close your wandb run 
wandb.finish()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Training epoch 0


Step:   0%|          | 0/9442 [00:00<?, ?it/s]

Validating model after epoch 0


Validating model:   0%|          | 0/4721 [00:00<?, ?it/s]