In [1]:
#!/usr/bin/env python
import torch

from lib.train_dataclasses import TrainConfig
from lib.train_dataclasses import TrainRun
from lib.train_dataclasses import OptimizerConfig
from lib.train_dataclasses import ComputeConfig

from lib.classification_metrics import create_classification_metrics
from lib.data_registry import DataSpiralsConfig
from lib.datasets.spiral_visualization import visualize_spiral
from lib.models.mlp import MLPClassConfig
from lib.generic_ablation import generic_ablation

from lib.distributed_trainer import distributed_train
from lib.ddp import ddp_setup
from lib.files import prepare_results
from lib.render_psql import setup_psql, add_artifact, add_train_run


def create_config(mlp_dim, ensemble_id):
    loss = torch.nn.CrossEntropyLoss()

    def ce_loss(output, batch):
        return loss(output["logits"], batch["target"])

    train_config = TrainConfig(
        model_config=MLPClassConfig(widths=[mlp_dim, mlp_dim]),
        train_data_config=DataSpiralsConfig(seed=0, N=1000),
        val_data_config=DataSpiralsConfig(seed=1, N=500),
        loss=ce_loss,
        optimizer=OptimizerConfig(
            optimizer=torch.optim.Adam, kwargs=dict(weight_decay=0.0001)
        ),
        batch_size=500,
        ensemble_id=ensemble_id,
    )
    train_eval = create_classification_metrics(visualize_spiral, 2)
    train_run = TrainRun(
        compute_config=ComputeConfig(distributed=False, num_workers=1),
        train_config=train_config,
        train_eval=train_eval,
        epochs=1,
        save_nth_epoch=20,
        validate_nth_epoch=20,
        notes=dict(purpose="ring window")
    )
    return train_run


config = create_config(100, 0)
add_train_run(config)
result_path = prepare_results("ring_windows", config)
setup_psql()
    #add_artifact(configs[0], "plot.png", path / "plot.png")


[Compute environment] Could not load env.py: 
Traceback (most recent call last):
  File "/Users/hampus/projects/equivariant-posteriors/lib/compute_env.py", line 17, in env
    import env
ModuleNotFoundError: No module named 'env'

[Compute environment] Using defaults
[Compute environment] paths: 
[Paths] checkpoints: checkpoints (/Users/hampus/projects/equivariant-posteriors/experiments/weather/checkpoints)
[Paths] locks: locks (/Users/hampus/projects/equivariant-posteriors/experiments/weather/locks)
[Paths] distributed_requests: distributed_requests (/Users/hampus/projects/equivariant-posteriors/experiments/weather/distributed_requests)
[Paths] artifacts: artifacts (/Users/hampus/projects/equivariant-posteriors/experiments/weather/artifacts)
[Paths] datasets: datasets (/Users/hampus/projects/equivariant-posteriors/experiments/weather/datasets)
[Compute environment] postgres_host: localhost
[Compute environment] postgres_port: 5432
[Compute environment] postgres_password: postgres


OperationalError: connection failed: could not receive data from server: Connection refused

In [2]:
import healpix
import chealpix as chp
import numpy as np

NSIDE = 4
n_pixels = healpix.nside2npix(NSIDE)
n_pixels

192

In [3]:
hp = np.zeros((n_pixels,), dtype=np.float32)

In [4]:
polar_idx = list(range(0, NSIDE))
current_idx = 0
north_idxs = []
north_eq_idxs = []
south_eq_idxs = []
south_idxs = []
for window_idx in polar_idx:
    north_idxs.append([ current_idx + i for i in range(4 * (window_idx + 1))])
    current_idx += 4 * (window_idx + 1)

for window_idx in range(NSIDE):
    north_eq_idxs.append([current_idx + i for i in range(4*NSIDE)])
    current_idx += 4*NSIDE

for window_idx in range(NSIDE - 1):
    south_eq_idxs.append([current_idx + i for i in range(4*NSIDE)])
    current_idx += 4*NSIDE

# nside 2, 0 -> 0

# nside 3, 0 -> 1
# nside 3, 1 -> 0
for window in reversed(north_idxs):
    south_idxs.append([n_pixels - 1 - idx for idx in window])
#for window_idx in polar_idx:
#    south_idxs.append([ current_idx + i for i in range(4 * ((NSIDE - 1 - window_idx) + 1))])
#    current_idx += 4 * ((NSIDE - 1 - window_idx) + 1)

In [5]:
north_idxs

[[0, 1, 2, 3],
 [4, 5, 6, 7, 8, 9, 10, 11],
 [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
 [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]]

In [6]:
len(north_eq_idxs + south_eq_idxs)

7

In [7]:
south_idxs

[[167,
  166,
  165,
  164,
  163,
  162,
  161,
  160,
  159,
  158,
  157,
  156,
  155,
  154,
  153,
  152],
 [179, 178, 177, 176, 175, 174, 173, 172, 171, 170, 169, 168],
 [187, 186, 185, 184, 183, 182, 181, 180],
 [191, 190, 189, 188]]

In [14]:
all_windows = north_idxs + north_eq_idxs + south_eq_idxs + south_idxs
colors = np.arange(len(all_windows))
np.random.shuffle(colors)
for idx, window in enumerate(all_windows):
    n_pixels_in_window = len(window)
    n_sub_windows = n_pixels_in_window // 16 + 1
    nest_idxs = chp.ring2nest(NSIDE, window)
    for sub_idx in range(n_sub_windows):
        sub_idxs = nest_idxs[sub_idx::n_sub_windows]
        #print(sub_idxs.shape)
        hp[sub_idxs] =  float(colors[(sub_idx + idx) % len(all_windows)])#float(2*idx % len(all_windows)) + 1

(4,)
(8,)
(12,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(8,)
(12,)
(8,)
(4,)


In [9]:
hp

array([ 4.,  7.,  4.,  2.,  7.,  2., 10., 13.,  7.,  2., 10., 13., 13.,
        8.,  8.,  0.,  4.,  7.,  4.,  2.,  7.,  2., 10., 13.,  7.,  2.,
       10., 13., 13.,  8.,  8.,  0.,  4.,  7.,  4.,  2.,  7.,  2., 10.,
       13.,  7.,  2., 10., 13., 13.,  8.,  8.,  0.,  4.,  7.,  4.,  2.,
        7.,  2., 10., 13.,  7.,  2., 10., 13., 13.,  8.,  8.,  0.,  5.,
        1.,  5., 14.,  1., 14., 11., 11.,  1., 14., 11., 11.,  4.,  7.,
        4.,  2.,  5.,  1.,  5., 14.,  1., 14., 11., 11.,  1., 14., 11.,
       11.,  4.,  7.,  4.,  2.,  5.,  1.,  5., 14.,  1., 14., 11., 11.,
        1., 14., 11., 11.,  4.,  7.,  4.,  2.,  5.,  1.,  5., 14.,  1.,
       14., 11., 11.,  1., 14., 11., 11.,  4.,  7.,  4.,  2.,  6.,  9.,
        9.,  3.,  3., 12.,  3., 12.,  3., 12.,  3., 12.,  5.,  1.,  5.,
       14.,  6.,  9.,  9.,  3.,  3., 12.,  3., 12.,  3., 12.,  3., 12.,
        5.,  1.,  5., 14.,  6.,  9.,  9.,  3.,  3., 12.,  3., 12.,  3.,
       12.,  3., 12.,  5.,  1.,  5., 14.,  6.,  9.,  9.,  3.,  3

In [249]:
def save_and_register(name, array):
    path = result_path / f"{name}.npy"

    np.save(
        path,
        array[None, :],
    )
    add_artifact(config, name, path)


In [250]:
save_and_register(f"window_nside_{NSIDE}.npy", hp)

[db] Connection to alvis2:5431
[db] Uploading artifact
[db] Chunk 1
[Database] Added artifact window_nside_1.npy: /mimer/NOBACKUP/groups/naiss2023-6-319/eqp/artifacts/results/ring_windows_git_ea96398_config_53db940/window_nside_1.npy.npy


In [10]:
all_windows = north_idxs + north_eq_idxs + south_eq_idxs + south_idxs
4 * 256 / 64

16.0

In [12]:
sum([4, 8, 12, 16, 16, 16, 16, 16, 16, 16, 16, 16, 12, 8, 4])

192

# With depth

In [18]:
def get_isolatitude_windows_hp(nside):
    polar_idx = list(range(0, nside))
    current_idx = 0
    north_idxs = []
    north_eq_idxs = []
    south_eq_idxs = []
    south_idxs = []
    for window_idx in polar_idx:
        north_idxs.append([ current_idx + i for i in range(4 * (window_idx + 1))])
        current_idx += 4 * (window_idx + 1)
    
    for window_idx in range(nside):
        north_eq_idxs.append([current_idx + i for i in range(4*nside)])
        current_idx += 4*nside
    
    for window_idx in range(nside - 1):
        south_eq_idxs.append([current_idx + i for i in range(4*nside)])
        current_idx += 4*nside
    
    # nside 2, 0 -> 0
    
    # nside 3, 0 -> 1
    # nside 3, 1 -> 0
    for window in reversed(north_idxs):
        south_idxs.append([n_pixels - 1 - idx for idx in window])

    return north_idxs + north_eq_idxs + south_eq_idxs + south_idxs

In [20]:
test = torch.zeros((1, 13
hp_windows = get_isolatitude_windows_hp(4)
for hp_window in hp_windows:
    for depth in zip(