In [27]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
import math
from tqdm.notebook import tqdm

import raster_relight as rr
import raster_dataloader as rd

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)


In [3]:
# Model
class LightMPL(torchvision.ops.MLP):
    def __init__(self, num_feats, hidden_channels, norm_layer=None, activation_layer=torch.nn.modules.activation.ReLU, inplace=None, bias=True, dropout=0.0):
        super().__init__(num_feats * 3, hidden_channels + [3], norm_layer, activation_layer, inplace, bias, dropout)

    def forward(self, x):
        return super().forward(x)


In [9]:
# Get the dataloader
def get_dataloader(batch_size, split='train'):
    is_train = split == 'train'
    "Get a dataloader for training or testing"
    full_dataset = rd.RasterDataset() # uses a config to decide what to load.
    loader = torch.utils.data.DataLoader(dataset=full_dataset, 
                                         batch_size=batch_size, 
                                         shuffle=True if is_train else False, 
                                         pin_memory=True, num_workers=0)
    return loader


In [20]:
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    with torch.inference_mode():
        correct = 0
        for step, (feats, target_vector) in enumerate(valid_dl):
            feats, target_vector = feats.to(device), target_vector.to(device)

            # Forward pass
            inputs = torch.flatten(feats, start_dim=1)
            outputs = model(inputs) # should be (batch_len, 3)
            val_loss += loss_func(outputs, target_vector)*target_vector.size(0)

            # Log one batch of images to the dashboard, always same batch_idx.
            # TODO: Implement image reconstruction using validation image and model
            # if i==batch_idx and log_images:
            #     log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset)


In [5]:
loader = get_dataloader(True, 1024)
len(loader)

Loading dataset from intrinsic_tester_sphere/: 1it [00:00,  2.71it/s]

in dl: feats have shape of torch.Size([2370105, 3, 3])





2315

In [29]:
# Launch 5 experiments, trying different dropout rates

# 🐝 initialise a wandb run
wandb.init(
    project="light-mlp-supervised-cosine",
    config={
        "epochs": 2,
        "batch_size": 1024,
        "lr": 1e-3,
        "dropout": 0.0, # random.uniform(0.01, 0.80),
        "num_feats": 3,
        "hidden_channels" : [256]*4 + [128], 
        })

# Copy your config 
config = wandb.config

# Get the data
train_dl = get_dataloader(batch_size=config.batch_size)
valid_dl = get_dataloader(batch_size=2*config.batch_size, split='valid')

n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

# MLP model
model = LightMPL(config.num_feats, config.hidden_channels, dropout=config.dropout)

# Make the loss and optimizer
cosine = nn.CosineSimilarity(dim=1, eps=1e-6)
# NOTE: don't forget that this needs to be exteded to be to twice the length for validation loop
cosine_target = torch.tensor([1.0]).to(device)
loss_func = lambda x,y: F.mse_loss(cosine(x,y), cosine_target)
#valid_loss_func = lambda x,y: F.mse_loss(cosine(x,y), torch.cat([cosine_target,cosine_target]))

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training
step_ct = 0
for epoch in tqdm(range(config.epochs),
                  desc="Epoch", 
                  total=config.epochs, 
                  position=0, leave=True, colour='green'):
    model.train()
    for step, (feats, target_vector) in tqdm(enumerate(train_dl), 
                                       total=len(train_dl), 
                                       desc=f"Step",
                                       position=1, leave=False, colour='red'):
        # Move to device
        feats, target_vector = feats.to(device), target_vector.to(device)

        # Forward pass
        inputs = torch.flatten(feats, start_dim=1)
        outputs = model(inputs) # should be (batch_len, 3)
        train_loss = loss_func(outputs, target_vector)

        # Optimization step
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Collect metrics
        metrics = {"train/train_loss": train_loss, 
                   "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch, 
                   }

        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb 
            wandb.log(metrics)

        step_ct += 1

    # Validation
    val_loss = validate_model(model, valid_dl, loss_func)

    # 🐝 Log train and validation metrics to wandb
    val_metrics = {"val/val_loss": val_loss}
    wandb.log({**metrics, **val_metrics})

    print(f"Train Loss: {train_loss:.3f} Valid Loss: {val_loss:3f}")

# 🐝 Close your wandb run 
wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113601444473412, max=1.0…

Loading dataset from intrinsic_tester_sphere/: 1it [00:00,  2.73it/s]


in dl: feats have shape of torch.Size([2370105, 3, 3])


Loading dataset from intrinsic_tester_sphere/: 1it [00:00,  2.89it/s]

in dl: feats have shape of torch.Size([2370105, 3, 3])





Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Step:   0%|          | 0/2315 [00:00<?, ?it/s]

  loss_func = lambda x,y: F.mse_loss(cosine(x,y), cosine_target)
  loss_func = lambda x,y: F.mse_loss(cosine(x,y), cosine_target)
  loss_func = lambda x,y: F.mse_loss(cosine(x,y), cosine_target)


Train Loss: 0.076 Valid Loss: 0.087240


Step:   0%|          | 0/2315 [00:00<?, ?it/s]

Train Loss: 0.087 Valid Loss: 0.082285


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▄▃▃▄▂▄▂▃▂▂▂▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▂▂▁▁▂▁▁▂▁▂
val/val_loss,█▁

0,1
train/epoch,2.0
train/train_loss,0.08689
val/val_loss,0.08229
