In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import math

In [1]:
def show_dataset_samples(dataset, labels, name, cols=4, rows=4):
    """Show a grid of randomly chosen dataset images, titled with file name and label.

    Samples are drawn with replacement via ``torch.randint``, so the same image can
    appear more than once.

    Args:
        dataset: sized, indexable dataset whose items are ``(image, target)`` pairs and
            which exposes a ``labels`` DataFrame with a ``plot`` column of file names.
            Images are permuted with ``(1, 2, 0)`` before display, so they are assumed
            to be channel-first ``(C, H, W)`` tensors — confirm against the dataset.
        labels: sequence indexed in parallel with ``dataset``; each value is shown with
            two decimals and an ``m`` suffix (presumably a height in metres — confirm).
        name: text appended to the figure's super-title.
        cols: number of grid columns (default 4).
        rows: number of grid rows (default 4).
    """
    figure = plt.figure(figsize=(8, 8))
    # One super-title for the whole figure; it does not depend on the sample.
    figure.suptitle("Plot samples " + name)
    for position in range(1, cols * rows + 1):
        sample_index = torch.randint(len(dataset), size=(1,)).item()
        img, _ = dataset[sample_index]
        label = labels[sample_index]
        # File name of the *sampled* row (not the grid position), with its last four
        # characters stripped — assumes 4-character extensions such as ".png" and that
        # dataset.labels rows align with dataset indices.
        filename = dataset.labels.iloc[sample_index]['plot'][0:-4]
        ax = figure.add_subplot(rows, cols, position)
        ax.set_title(f'{filename}: {label:.2f}m')
        ax.axis("off")
        ax.imshow(img.permute(1, 2, 0))
    plt.show()

In [None]:
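# Find LR: learning-rate range test — sweep the LR exponentially over one pass of the dataloader and record the loss per batch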

def find_lr(model, loss_fn, optimizer, dataloader, init_value=1e-8, final_value=10.0):
    """Run a learning-rate range test over one pass of ``dataloader``.

    The learning rate of ``optimizer.param_groups[0]`` starts at ``init_value`` and is
    multiplied by a constant factor after every batch, so that with ``len(dataloader)``
    batches the last batch runs at ``final_value``. For each batch the loss is recorded
    *before* the optimizer step, together with ``log10`` of the learning rate in effect.

    NOTE: this really trains ``model``. Weights, optimizer state and the lr of param
    group 0 are all left modified, so reload/re-create them before real training.
    Only param group 0 is swept; other groups keep their own learning rates.

    Args:
        model: callable mapping ``inputs`` to predictions.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        optimizer: torch optimizer whose ``param_groups[0]["lr"]`` is swept.
        dataloader: sized iterable of ``(inputs, labels)`` pairs. Labels get a trailing
            singleton dimension (presumably to match a ``(batch, 1)`` regression
            output — confirm against the model).
        init_value: learning rate used for the first batch.
        final_value: learning rate reached on the last batch.

    Returns:
        ``(log_lrs, losses)``: lists of Python floats, one entry per recorded batch.
        The sweep stops early, without recording the offending batch, when the loss is
        not finite or exceeds 4x the best loss seen so far.
    """
    # A single-batch loader would otherwise divide by zero in the exponent below.
    num_steps = max(len(dataloader) - 1, 1)
    update_step = (final_value / init_value) ** (1 / num_steps)
    lr = init_value
    optimizer.param_groups[0]["lr"] = lr

    best_loss = float("inf")
    losses = []
    log_lrs = []
    for batch_num, (inputs, labels) in enumerate(dataloader, start=1):
        labels = labels.unsqueeze(-1)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        # Track a plain float so no loss tensor is kept alive between iterations.
        loss_value = loss.item()

        # Stop once the loss diverges: NaN/inf, or much worse than the best so far.
        if not math.isfinite(loss_value) or (batch_num > 1 and loss_value > 4 * best_loss):
            print("exploding loss at batch: ", batch_num)
            return log_lrs, losses

        best_loss = min(best_loss, loss_value)
        losses.append(loss_value)
        log_lrs.append(math.log10(lr))

        loss.backward()
        optimizer.step()

        # Exponential schedule: the next batch uses lr * update_step.
        lr *= update_step
        optimizer.param_groups[0]["lr"] = lr

    return log_lrs, losses

In [None]:
def find_steepest_point(log_lrs, losses):
    """Return the log10 learning rate at which the loss dropped the most in one step.

    ``losses[i + 1] - losses[i]`` is taken as the slope at ``log_lrs[i]``, and the entry
    of ``log_lrs`` with the most negative slope is returned. Ties resolve to the earliest
    index; NaN slopes are ignored.

    Args:
        log_lrs: sequence of log10 learning rates aligned with ``losses``.
        losses: sequence of loss values.

    Returns:
        The element of ``log_lrs`` at the steepest single-step decrease.

    Raises:
        ValueError: if fewer than two losses are given or the loss never decreases.
    """
    slopes = np.diff(np.asarray(losses, dtype=float))
    if slopes.size == 0:
        raise ValueError("need at least two losses to compute a slope")
    # Only decreasing steps are candidates; everything else (including NaN) -> +inf.
    candidates = np.where(slopes < 0, slopes, np.inf)
    if not np.isfinite(candidates).any():
        raise ValueError("losses never decrease; cannot locate a steepest point")
    # argmin returns the first occurrence of the most negative slope.
    return log_lrs[int(np.argmin(candidates))]

In [2]:
def get_means_stds(dataset):
    """Compute per-channel mean and standard deviation over an entire image dataset.

    Every sample is loaded and stacked into one tensor, so the whole dataset must fit
    in memory.

    Args:
        dataset: iterable of samples whose first element is an image tensor; all images
            must share one shape. Dimension 0 of each image is treated as the channel
            axis (presumably ``(C, H, W)`` — confirm).

    Returns:
        ``(means, stds)``: two lists of Python floats with one entry per channel.
        ``stds`` uses ``torch.std``'s default Bessel-corrected estimator.
    """
    stacked = torch.stack([sample[0] for sample in dataset])
    # One view per channel; the channel count comes from the data rather than
    # assuming 3 (RGB), so grayscale or RGBA inputs are handled correctly.
    channels = torch.unbind(stacked, dim=1)
    means = [channel.mean().item() for channel in channels]
    stds = [channel.std().item() for channel in channels]
    return means, stds