In [12]:
#!/usr/bin/env python
import torch

from lib.train_dataclasses import TrainConfig
from lib.train_dataclasses import TrainRun
from lib.train_dataclasses import OptimizerConfig
from lib.train_dataclasses import ComputeConfig

from lib.classification_metrics import create_classification_metrics
from lib.data_registry import DataSpiralsConfig
from lib.datasets.spiral_visualization import visualize_spiral
from lib.models.mlp import MLPClassConfig
from lib.generic_ablation import generic_ablation

from lib.distributed_trainer import distributed_train
from lib.ddp import ddp_setup
from lib.files import prepare_results
from lib.render_psql import setup_psql, add_artifact, add_train_run


def create_config(mlp_dim, ensemble_id):
    loss = torch.nn.CrossEntropyLoss()

    def ce_loss(output, batch):
        return loss(output["logits"], batch["target"])

    train_config = TrainConfig(
        model_config=MLPClassConfig(widths=[mlp_dim, mlp_dim]),
        train_data_config=DataSpiralsConfig(seed=0, N=1000),
        val_data_config=DataSpiralsConfig(seed=1, N=500),
        loss=ce_loss,
        optimizer=OptimizerConfig(
            optimizer=torch.optim.Adam, kwargs=dict(weight_decay=0.0001)
        ),
        batch_size=500,
        ensemble_id=ensemble_id,
    )
    train_eval = create_classification_metrics(visualize_spiral, 2)
    train_run = TrainRun(
        compute_config=ComputeConfig(distributed=False, num_workers=1),
        train_config=train_config,
        train_eval=train_eval,
        epochs=1,
        save_nth_epoch=20,
        validate_nth_epoch=20,
        notes=dict(purpose="isolatitude window")
    )
    return train_run


config = create_config(100, 0)
add_train_run(config)
result_path = prepare_results("ring_windows", config)
setup_psql()
    #add_artifact(configs[0], "plot.png", path / "plot.png")


[db] Connection to localhost:5431


In [2]:
import healpix
import chealpix as chp
import numpy as np

NSIDE = 4
n_pixels = healpix.nside2npix(NSIDE)
n_pixels

192

In [3]:
hp = np.zeros((n_pixels,), dtype=np.float32)

In [4]:
polar_idx = list(range(0, NSIDE))
current_idx = 0
north_idxs = []
north_eq_idxs = []
south_eq_idxs = []
south_idxs = []
for window_idx in polar_idx:
    north_idxs.append([ current_idx + i for i in range(4 * (window_idx + 1))])
    current_idx += 4 * (window_idx + 1)

for window_idx in range(NSIDE):
    north_eq_idxs.append([current_idx + i for i in range(4*NSIDE)])
    current_idx += 4*NSIDE

for window_idx in range(NSIDE - 1):
    south_eq_idxs.append([current_idx + i for i in range(4*NSIDE)])
    current_idx += 4*NSIDE

# nside 2, 0 -> 0

# nside 3, 0 -> 1
# nside 3, 1 -> 0
for window in reversed(north_idxs):
    south_idxs.append([n_pixels - 1 - idx for idx in window])
#for window_idx in polar_idx:
#    south_idxs.append([ current_idx + i for i in range(4 * ((NSIDE - 1 - window_idx) + 1))])
#    current_idx += 4 * ((NSIDE - 1 - window_idx) + 1)

In [5]:
north_idxs

[[0, 1, 2, 3],
 [4, 5, 6, 7, 8, 9, 10, 11],
 [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
 [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]]

In [6]:
len(north_eq_idxs + south_eq_idxs)

7

In [7]:
south_idxs

[[167,
  166,
  165,
  164,
  163,
  162,
  161,
  160,
  159,
  158,
  157,
  156,
  155,
  154,
  153,
  152],
 [179, 178, 177, 176, 175, 174, 173, 172, 171, 170, 169, 168],
 [187, 186, 185, 184, 183, 182, 181, 180],
 [191, 190, 189, 188]]

In [8]:
all_windows = north_idxs + north_eq_idxs + south_eq_idxs + south_idxs
colors = np.arange(len(all_windows))
np.random.shuffle(colors)
for idx, window in enumerate(all_windows):
    n_pixels_in_window = len(window)
    n_sub_windows = n_pixels_in_window // 16 + 1
    nest_idxs = chp.ring2nest(NSIDE, window)
    for sub_idx in range(n_sub_windows):
        sub_idxs = nest_idxs[sub_idx::n_sub_windows]
        #print(sub_idxs.shape)
        hp[sub_idxs] =  float(colors[(sub_idx + idx) % len(all_windows)])#float(2*idx % len(all_windows)) + 1

In [9]:
hp

array([ 5., 14.,  5.,  6., 14.,  6.,  7.,  8., 14.,  6.,  7.,  8.,  8.,
       12., 12.,  1.,  5., 14.,  5.,  6., 14.,  6.,  7.,  8., 14.,  6.,
        7.,  8.,  8., 12., 12.,  1.,  5., 14.,  5.,  6., 14.,  6.,  7.,
        8., 14.,  6.,  7.,  8.,  8., 12., 12.,  1.,  5., 14.,  5.,  6.,
       14.,  6.,  7.,  8., 14.,  6.,  7.,  8.,  8., 12., 12.,  1., 10.,
        2., 10.,  3.,  2.,  3.,  4.,  4.,  2.,  3.,  4.,  4.,  5., 14.,
        5.,  6., 10.,  2., 10.,  3.,  2.,  3.,  4.,  4.,  2.,  3.,  4.,
        4.,  5., 14.,  5.,  6., 10.,  2., 10.,  3.,  2.,  3.,  4.,  4.,
        2.,  3.,  4.,  4.,  5., 14.,  5.,  6., 10.,  2., 10.,  3.,  2.,
        3.,  4.,  4.,  2.,  3.,  4.,  4.,  5., 14.,  5.,  6.,  9.,  0.,
        0., 13., 13., 11., 13., 11., 13., 11., 13., 11., 10.,  2., 10.,
        3.,  9.,  0.,  0., 13., 13., 11., 13., 11., 13., 11., 13., 11.,
       10.,  2., 10.,  3.,  9.,  0.,  0., 13., 13., 11., 13., 11., 13.,
       11., 13., 11., 10.,  2., 10.,  3.,  9.,  0.,  0., 13., 13

In [13]:
def save_and_register(name, array):
    path = result_path / f"{name}.npy"

    np.save(
        path,
        array[None, :],
    )
    add_artifact(config, name, path)


In [250]:
save_and_register(f"window_nside_{NSIDE}.npy", hp)

[db] Connection to alvis2:5431
[db] Uploading artifact
[db] Chunk 1
[Database] Added artifact window_nside_1.npy: /mimer/NOBACKUP/groups/naiss2023-6-319/eqp/artifacts/results/ring_windows_git_ea96398_config_53db940/window_nside_1.npy.npy


In [10]:
all_windows = north_idxs + north_eq_idxs + south_eq_idxs + south_idxs
4 * 256 / 64

16.0

In [12]:
sum([4, 8, 12, 16, 16, 16, 16, 16, 16, 16, 16, 16, 12, 8, 4])

192

# With depth

In [129]:
import healpix
import chealpix as chp

def get_isolatitude_windows_hp(nside):
    polar_idx = list(range(0, nside))
    current_idx = 0
    north_idxs = []
    north_eq_idxs = []
    south_eq_idxs = []
    south_idxs = []
    n_pixels = healpix.nside2npix(nside)
    for window_idx in polar_idx:
        north_idxs.append([ current_idx + i for i in range(4 * (window_idx + 1))])
        current_idx += 4 * (window_idx + 1)
    
    for window_idx in range(nside):
        north_eq_idxs.append([current_idx + i for i in range(4*nside)])
        current_idx += 4*nside
    
    for window_idx in range(nside - 1):
        south_eq_idxs.append([current_idx + i for i in range(4*nside)])
        current_idx += 4*nside
    
    # nside 2, 0 -> 0
    
    # nside 3, 0 -> 1
    # nside 3, 1 -> 0
    for window in reversed(north_idxs):
        south_idxs.append([n_pixels - 1 - idx for idx in window])

    return north_idxs + north_eq_idxs + south_eq_idxs + south_idxs

def to_interspersed_windows(nside, max_size, window):
    n_pixels_in_window = len(window)
    n_sub_windows = n_pixels_in_window // max_size + 1
    nest_idxs = chp.ring2nest(nside, window)
    sub_windows = []
    for sub_idx in range(n_sub_windows):
        sub_windows.append(nest_idxs[sub_idx::n_sub_windows].tolist())
    return sub_windows

def flattened_interspersed(nside, max_window_size, windows):
    interspersed = [to_interspersed_windows(nside, max_window_size, window) for window in windows]
    return [ window for subwins in interspersed for window in subwins ]

def pad_windows(max_window_size, windows):
    padded_windows = []
    current_padded = []
    for idx, window in enumerate(windows):
        fits_in_window = len(current_padded) + len(window) <= max_window_size
        if fits_in_window:
            current_padded.extend(window)
        if not fits_in_window:
            current_padded.extend([current_padded[-1]] * (max_window_size - len(current_padded)))
            padded_windows.append(current_padded)
            current_padded = list(window)
        if idx == len(windows) - 1 and len(current_padded) > 0:
            current_padded.extend([current_padded[-1]] * (max_window_size - len(current_padded)))
            padded_windows.append(current_padded)
    return padded_windows

def test_pad_windows(nside):
    max_window_size = 16
    hp_windows = get_isolatitude_windows_hp(nside)
    interspersed = flattened_interspersed(nside, max_window_size, hp_windows)
    padded_windows = pad_windows(max_window_size, interspersed)
    
    data = torch.rand((2, 3, healpix.nside2npix(nside), 48))
    data_pre = data.clone()
    indices = torch.tensor(padded_windows)

    # Extract windows
    windowed = data[:, :, indices, :]

    # Use windows to reconstruct original tensor
    new = torch.zeros(data.shape)
    new[:, :, indices, :] = windowed
    assert (new - data_pre).sum() == 0.0

def window_reverse(windows, window_size, D, N):
    window_size_d, window_size_hp = window_size
    nside = healpix.npix2nside(N)

    hp_windows = get_isolatitude_windows_hp(nside)
    interspersed = flattened_interspersed(nside, window_size_hp, hp_windows)
    padded_windows = pad_windows(window_size_hp, interspersed)

    indices = torch.tensor(padded_windows)    
    
    Nw, W = indices.shape

    B = int(windows.shape[0] / (D * N // (window_size_hp * window_size_d)))
    x = windows.view(
        B, D // window_size_d, Nw, window_size_d, W, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5)
    
    # B, Nd, Wd, Nw, W, C
    # 0   1   2   3  4  5
    x = x.contiguous().view(B, D, Nw, W, C)
    
    new = torch.zeros(data.shape)
    new[:, :, indices, :] = x

    return new

def window_partition(x: torch.Tensor, window_size):
    window_size_d, window_size_hp = window_size
    
    nside = healpix.npix2nside(x.shape[2])
    hp_windows = get_isolatitude_windows_hp(nside)
    interspersed = flattened_interspersed(nside, window_size_hp, hp_windows)
    padded_windows = pad_windows(window_size_hp, interspersed)

    indices = torch.tensor(padded_windows)    
    windowed = data[:, :, indices, :]

    B, D, Nw, W, C = windowed.shape
    x = windowed.view(B, D // window_size_d, window_size_d, Nw, W, C)
    
    # B, Nd, Wd, Nw, W, C
    # 0   1   2   3  4  5
    
    x = x.permute(0, 1, 3, 2, 4, 5)
    windows = x.contiguous().view(-1, window_size_d * window_size_hp, C)
    return windows

In [136]:
def window_partition_nest(x: torch.Tensor, window_size):
    """
    Args:
        x: (B, D, N, C)
        window_size (int,int): Must be a power of 2 in the healpy grid.

    Returns:
        windows: (num_windows*B, window_size_d * window_size_hp , C)
    """
    # assert that window_size is a power of 2
    # assert (math.log(window_size) / math.log(2)) % 1 == 0

    B, D, N, C = x.shape
    window_size_d, window_size_hp = window_size
    x = x.view(
        B, D // window_size_d, window_size_d, N // window_size_hp, window_size_hp, C
    )
    # B, D//wd, wd, N//whp, whp, c
    # 0  1      2   3       4    5
    # =>
    # B, D//wd, N//whp, wd, whp, c
    # 0  1      3       2   4    5
    x = x.permute(0, 1, 3, 2, 4, 5)
    windows = x.contiguous().view(-1, window_size_d * window_size_hp, C)
    return windows


def window_reverse_nest(windows, window_size, D, N):
    """
    Args:
        windows: (num_windows*B, window_size, C)
        window_size (int): Must be a power of 2 in the healpy grid
        N (int): Number of pixels in the healpy grid

    Returns:
        x: (B, N, C)
    """
    # assert that window_size is a power of 2
    # assert (math.log(window_size) / math.log(2)) % 1 == 0
    window_size_d, window_size_hp = window_size

    B = int(windows.shape[0] / (D * N // (window_size_hp * window_size_d)))
    x = windows.view(
        B, D // window_size_d, N // window_size_hp, window_size_d, window_size_hp, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5)
    x = x.contiguous().view(B, D, N, -1)
    return x

In [125]:
test_pad_windows(4)

In [138]:
data = torch.rand((2, 8, healpix.nside2npix(256), 48))
windows = window_partition(data, (2, 64))
post = window_reverse(windows, (2, 64), 8, healpix.nside2npix(256))
(post - data).sum()

tensor(0.)

In [121]:
print(data.shape)
print(healpix.nside2npix(4))
#nside = healpix.npix2nside(data.shape[2])
nside = 4
window_size_hp = 16
window_size_d = 2
hp_windows = get_isolatitude_windows_hp(nside)
interspersed = flattened_interspersed(nside, window_size_hp, hp_windows)
padded_windows = pad_windows(window_size_hp, interspersed)


data = torch.rand((2, 8, healpix.nside2npix(4), 48))

N = healpix.nside2npix(4)

data_pre = data.clone()
indices = torch.tensor(padded_windows)
#indices = indices.unsqueeze(0).unsqueeze(0)
windowed = data[:, :, indices, :]

B, D, Nw, W, C = windowed.shape
x = windowed.view(B, D // window_size_d, window_size_d, Nw, W, C)

# B, Nd, Wd, Nw, W, C
# 0   1   2   3  4  5

x = x.permute(0, 1, 3, 2, 4, 5)
windows = x.contiguous().view(-1, window_size_d * window_size_hp, C)
windows.shape


B = int(windows.shape[0] / (D * N // (window_size_hp * window_size_d)))
x = windows.view(
    B, D // window_size_d, Nw, window_size_d, W, -1
)
x = x.permute(0, 1, 3, 2, 4, 5)

# B, Nd, Wd, Nw, W, C
# 0   1   2   3  4  5
x = x.contiguous().view(B, D, Nw, W, C)

new = torch.zeros(data.shape)
new[:, :, indices, :] = x

(new - data_pre).sum()
#x = x.contiguous().view(B, D, N, -1)

#hp_windows[:20]
#padded_windows[:5]
#print(indices.shape)
#for hp_window in padded_windows:
#    tindex = torch.tensor(hp_window)
#    window = torch.index_select(data, 2, tindex)
#    data[:, :, tindex, :] = window

#(data_pre - data).sum()

torch.Size([2, 8, 192, 48])
192


tensor(0.)

In [122]:
windowed.shape

torch.Size([2, 8, 13, 16, 48])

In [123]:
indices.shape

torch.Size([13, 16])

# Data tests

In [2]:
import experiments.weather.models.swin_hp_pangu as shp
import lib.data_factory as data_factory
import experiments.weather.persisted_configs.train_nside64 as c
import experiments.weather.models.swin_hp_pangu as shp
import torch
import experiments.weather.data as data


df = data_factory.get_factory()
config = c.create_config(0)

data.serialize_dataset_statistics(config.train_config.model_config.nside, test_with_one_sample=True)
ds = df.create(config.train_config.train_data_config)
pe = shp.PatchEmbed(config.train_config.model_config, ds.__class__.data_spec(config.train_config.train_data_config))
dl = torch.utils.data.DataLoader(ds, batch_size=1)
batch = next(iter(dl))
surface = batch["input_surface"]
upper = batch["input_upper"]
data = pe(surface, upper)
config.train_config.model_config.nside


Starting with batch size 8
1 samples in 2.20s, 0 left
[92m[Compute environment] paths: 
[92m[Paths] checkpoints: checkpoints (/Users/hampus/projects/equivariant-posteriors/experiments/weather/checkpoints)[0m
[92m[Paths] locks: locks (/Users/hampus/projects/equivariant-posteriors/experiments/weather/locks)[0m
[92m[Paths] distributed_requests: distributed_requests (/Users/hampus/projects/equivariant-posteriors/experiments/weather/distributed_requests)[0m
[92m[Paths] artifacts: artifacts (/Users/hampus/projects/equivariant-posteriors/experiments/weather/artifacts)[0m
[92m[Paths] datasets: datasets (/Users/hampus/projects/equivariant-posteriors/experiments/weather/datasets)[0m[0m
[92m[Compute environment] postgres_host: localhost[0m
[92m[Compute environment] postgres_port: 5431[0m
[92m[Compute environment] postgres_password: herdeherde[0m
Saved npy datasets/driscoll_healy_False_end_year_2017_nside_64_start_year_2007_version_10.npy


64

In [18]:
test = surface.unsqueeze(1).permute(0, 1, 3, 2)
test = torch.concat([test, test], dim=1)
N = test.shape[-2]
test.shape

torch.Size([1, 2, 49152, 4])

In [41]:
import experiments.weather.models.hp_windowing_isolatitude as win
import experiments.weather.models.hp_shifting as shift

In [50]:
test = torch.zeros((1, 2, 12 * 16**2, 1))
N = test.shape[-2]

In [51]:
windows = win.window_partition(test, (2, 16))

In [52]:
windows.shape

torch.Size([239, 32, 1])

In [65]:
import numpy as np
colors = np.arange(windows.shape[0])
#np.random.shuffle(colors)
for idx in range(windows.shape[0]):
    windows[idx, :, :] = 0 #float(colors[idx])
windows[100, :, :] = 1

In [66]:
window_debug = win.window_reverse(windows, (2, 16), 2, N)

In [67]:
shifter = shift.RingShift(nside=16, base_pix=12, window_size=(2, 16), shift_size=2*16, input_resolution=(2, N))
shifted = shifter.shift(window_debug)

windows = win.window_partition(shifted, (2, 16))
windows[100, :, :] = 2

In [68]:
window_debug = win.window_reverse(windows, (2, 16), 2, N)
window_debug = window_debug.permute(0, 1, 3, 2)
window_debug.shape

torch.Size([1, 2, 1, 3072])

In [69]:
save_and_register(f"isolatitude_window_small.npy", window_debug)

[db] Connection to localhost:5431
[db] Uploading artifact
[db] Chunk 1
[Database] Added artifact isolatitude_window_small.npy: artifacts/results/ring_windows_git_3d77aed_config_e4a5d1d/isolatitude_window_small.npy.npy


# Equivariance test

In [70]:
import experiments.weather.models.swin_hp_pangu as swinhppangu
import experiments.weather.models.swin_hp_pangu_isolatitude as swinhppangu_iso