<a href="https://colab.research.google.com/github/emmanuel-nwogu/prefit_image_fitting/blob/master/pretraining_image_fitting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab VM at /content/gdrive so later cells can
# write under it (output_dir below lives in /content/gdrive/MyDrive).
# force_remount=True is passed so re-running this cell in a live runtime
# remounts rather than reusing an existing mount.
from google.colab import drive
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
import random
import sys
import time

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# These most likely have to be dragged and dropped into colab from https://github.com/emmanuel-nwogu/prefit_image_fitting
from cifair import TriplesCiFAIR10
from modules import Siren
from utils import Params, FittingMode, get_mgrid, fit_one_image, save_run_info

# Passed as `determinism_seed` to TriplesCiFAIR10 in the experiment cell.
# Presumably fixes which image triples the dataset produces — confirm in cifair.py.
SEED_FOR_DATASET = 42

In [None]:
# Release unoccupied memory held by PyTorch's CUDA caching allocator; handy
# when re-running the notebook in a runtime that already holds GPU memory.
torch.cuda.empty_cache()

In [None]:
# Relative folder handed to TriplesCiFAIR10 as `root` (with download=True).
cifair10_save_folder = r"data/cifair10"
# Base folder on the mounted Drive; each experiment saves into
# f"{output_dir}/{experiment_name}".
output_dir = r"/content/gdrive/MyDrive/ImageFitting/run_output"

In [None]:
# Grid search over SIREN depth (hidden layers) and width (hidden features).
# For every (depth, width) experiment, each image triple (A, B, C) from the
# dataloader is fitted once per FittingMode: the model is reset to the same
# random initialisation, optionally pre-fitted to C or B, then fitted to A.
# Metrics for every (triple, mode) pair are appended through save_run_info.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

# Resume controls: experiments before (start_hl_index, start_hf_index) in the
# grid are skipped, and within the first executed experiment the runs before
# start_run_index are skipped.
hls, start_hl_index = [2, 3, 4, 5, 6, 7, 8], 0
hfs, start_hf_index = [32, 64, 128, 256, 512, 1024], 0
assert start_hl_index < len(hls)
assert start_hf_index < len(hfs)
start_run_index = 0
# Loop-invariant: flat index of the first experiment that should actually run.
num_exps_to_skip = len(hfs) * start_hl_index + start_hf_index
for hl_index, hidden_layers in enumerate(hls):
    for hf_index, hidden_features in enumerate(hfs):
        experiment_name = f"{hidden_layers}hls_{hidden_features}hfs"
        current_exp_save_folder = f"{output_dir}/{experiment_name}"
        prev_exps = len(hfs) * hl_index + hf_index  # flat index of this experiment
        if prev_exps < num_exps_to_skip:
            print(f"Skipping {experiment_name}...")
            continue
        start_time = time.time()
        params = Params(hidden_layers=hidden_layers, hidden_features=hidden_features, learning_rate=1e-4,
                        batch_size=1, fit_epochs=2000, num_triples=20)
        params.save_json(save_path=current_exp_save_folder)

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        triples_dataset = TriplesCiFAIR10(determinism_seed=SEED_FOR_DATASET, root=cifair10_save_folder, train=True,
                                          download=True, transform=transform, num_triples=params.num_triples)

        triples_dataloader = DataLoader(triples_dataset, shuffle=True, batch_size=params.batch_size,
                                        pin_memory=True,
                                        num_workers=0)

        # Three images in `images` from dataloader
        # 0 -> random image of class C0, A.
        # 1 -> random image of class C1, B where C1 != C0 and A != B. Last conditional is redundant but still :)
        # 2 -> another random image of class C0, C where C != A.

        # coords is always the same, so it seems memory-inefficient to store it in every batch.
        coords = get_mgrid(sidelen=32, dim=2)
        is_first_run_in_exp = True
        # One seed per triple. Indexing by run_index below assumes the dataloader
        # yields at most params.num_triples batches — TODO confirm in cifair.py.
        run_seeds = random.sample(range(1, sys.maxsize), params.num_triples)
        print(f"Starting {experiment_name}...")
        for run_index, (images, labels) in enumerate(triples_dataloader):
            if run_index < start_run_index:
                print(f"Skipping {experiment_name}...{run_index}")
                continue
            # Run-level skipping only applies to the first executed experiment.
            start_run_index = -1
            torch.manual_seed(
                run_seeds[run_index])  # The random init for each triple is generated using a unique rand seed.
            labels = labels[0]
            # TODO: Save time and avg it? maybe leave this to tqdm? then run a lot of these (>100) on colab asap!
            siren_mlp = Siren(in_features=params.in_features, hidden_features=params.hidden_features,
                              hidden_layers=params.hidden_layers, out_features=params.out_features)
            if is_first_run_in_exp:
                # We don't change is_first_run_in_exp to False yet since we need it for save_run_info.
                print(siren_mlp)
            siren_mlp = siren_mlp.to(device)
            # Module.state_dict() returns tensors that share storage with the live
            # parameters, so an un-cloned "snapshot" would silently follow any
            # in-place training. Clone every tensor so the random initialisation
            # can really be restored before each fitting mode.
            initial_state_dict = {name: tensor.detach().clone()
                                  for name, tensor in siren_mlp.state_dict().items()}
            for fitting_mode in FittingMode:
                # Every mode starts from the same random initialisation,
                # regardless of the order in which the enum lists the modes.
                siren_mlp.load_state_dict(initial_state_dict)

                if fitting_mode == FittingMode.FIT_FROM_RANDOM_INIT:
                    pretrain_image = None  # no pre-fit
                elif fitting_mode == FittingMode.FIT_FROM_MODEL_FIT_TO_IMAGE_FROM_SAME_CLASS:
                    pretrain_image = images[2]  # fit to C
                else:
                    pretrain_image = images[1]  # fit to B

                if pretrain_image is not None:
                    fit_one_image(model=siren_mlp, coords=coords, image_to_fit=pretrain_image, params=params,
                                  device=device,
                                  plot_losses=False, current_fitting_mode=fitting_mode, pretrain=True)

                tqdm_desc = f"{experiment_name}: triple {run_index}/{params.num_triples}, labels {labels.tolist()}, " \
                            f"fit mode {fitting_mode.name}"
                # fit to A
                fitting_metrics = fit_one_image(model=siren_mlp, coords=coords, image_to_fit=images[0],
                                                params=params,
                                                device=device, tqdm_description=tqdm_desc, plot_losses=False,
                                                current_fitting_mode=fitting_mode, pretrain=False,
                                                calculate_psnr=True)
                run_info = {"labels": labels.tolist(), "fitting_mode": fitting_mode.value,
                            "metrics": fitting_metrics}
                # NOTE(review): when resuming mid-experiment (start_run_index > 0) this
                # flag is still True for the first executed run; if clear_existing_file
                # truncates the file as its name suggests, records of earlier runs are
                # lost — confirm against save_run_info.
                save_run_info(run_info, current_exp_save_folder, clear_existing_file=is_first_run_in_exp)

                # Only the very first save of an experiment may clear the file.
                is_first_run_in_exp = False

        exp_duration = time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))
        print(f"DONE: {experiment_name}. hl_index: {hl_index}, hf_index: {hf_index}. Duration: {exp_duration}")


In [None]:
#@title So I know when to disconnect the runtime lol
# Plays a short beep in the browser tab through Colab's JavaScript bridge;
# placed after the experiment cell so it sounds once the preceding cells finish.
from google.colab import output
output.eval_js('new Audio("https://upload.wikimedia.org/wikipedia/commons/0/05/Beep-09.ogg").play()')