# Resources
[deepsphere (paper)](https://arxiv.org/abs/1810.12186)
> [deepsphere-pytorch (github)](https://github.com/deepsphere/deepsphere-pytorch)

[Cosmological Parameter Estimation and Inference using Deep Summaries (paper)](https://arxiv.org/abs/2107.09002)
> [cosmo_estimators (github)](https://github.com/jafluri/cosmo_estimators)


In [None]:
import warnings
from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
from monai.utils import alias, deprecated_arg, export

In [6]:
"""Example script for running DeepSphere U-Net on reduced AR_TC dataset.
"""


import numpy as np
import torch
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup
from ignite.contrib.handlers.tensorboard_logger import GradsHistHandler, OptimizerParamsHandler, OutputHandler, TensorboardLogger, WeightsHistHandler
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.handlers import EarlyStopping, TerminateOnNan
from ignite.metrics import EpochMetric
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from deepsphere.data.datasets.dataset import ARTCDataset
from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor
from deepsphere.models.spherical_unet.unet_model import SphericalUNet
from deepsphere.utils.initialization import init_device
from deepsphere.utils.parser import create_parser, parse_config
from deepsphere.utils.stats_extractor import stats_extractor

## Load data

In [62]:
import numpy as np
import torch
from ignite.contrib.handlers.param_scheduler import create_lr_scheduler_with_warmup
from ignite.contrib.handlers.tensorboard_logger import GradsHistHandler, OptimizerParamsHandler, OutputHandler, TensorboardLogger, WeightsHistHandler
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.handlers import EarlyStopping, TerminateOnNan
from ignite.metrics import EpochMetric
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from deepsphere.data.transforms.transforms import Normalize, Permute, ToTensor
from deepsphere.utils.initialization import init_dataset_temp, init_device, init_unet_temp
from deepsphere.utils.parser import create_parser, parse_config
from deepsphere.utils.stats_extractor import stats_extractor

In [57]:
import numpy as np

from deepsphere.utils.initialization import init_dataset_temp, init_device, init_unet_temp
from deepsphere.utils.initialization import init_device

import torch
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader,random_split
from torch.utils.tensorboard import SummaryWriter

import os
import warnings

# Silence every warning for the rest of the notebook session.
warnings.filterwarnings("ignore")

# Root directory of the reconstructed-kappa data set, relative to $HOME.
HOME = os.path.expandvars("$HOME")
dataset_dir = f"{HOME}/code/shear2convergence/recon_kappa/"

In [1]:


# Load a few realisations (indices 997-999) of the noisy shear components and
# the convergence map into module-level variables named x1_<iii>, x2_<iii>
# and y_<iii>.
def _exec(args):
    """Execute the source string ``args`` in this module's global namespace.

    Returns:
        None (the return value of ``exec``).
    """
    return exec(args, globals())


for i in range(997, 1000):
    # Bind the arrays straight into the global namespace instead of building
    # source strings and exec()-ing them: same global names, no evaluation
    # of generated code.
    globals()[f"x1_{i:03}"] = np.load(dataset_dir + f"gamma1_w_noise/g1_{i:03}.npy")
    globals()[f"x2_{i:03}"] = np.load(dataset_dir + f"gamma2_w_noise/g2_{i:03}.npy")
    globals()[f"y_{i:03}"] = np.load(dataset_dir + f"kappa/kappa_{i:03}.npy")


NameError: name 'np' is not defined

In [37]:
# Make newly created floating-point tensors default to the CPU float32 type.
# NOTE(review): recent PyTorch releases deprecate set_default_tensor_type in
# favour of torch.set_default_dtype(torch.float32) — confirm for the installed
# torch version.
torch.set_default_tensor_type(torch.FloatTensor)

In [18]:

class MyDataset(Dataset):
    """Map-style dataset pairing noisy shear maps (gamma) with convergence maps (kappa).

    Sample ``index`` is read from three ``.npy`` files whose names carry the
    index zero-padded to (at least) three digits::

        <gamma1_dir>/g1_<iii>.npy
        <gamma2_dir>/g2_<iii>.npy
        <kappa_dir>/kappa_<iii>.npy

    Each sample is a dict of float32 tensors:

    * ``"gamma"``: the g1 and g2 maps stacked along a new leading axis
      (2 channels),
    * ``"kappa"``: the kappa map with a leading axis of size 1 (1 channel).

    Args:
        gamma1_dir (str): directory containing the ``g1_*.npy`` files.
        gamma2_dir (str): directory containing the ``g2_*.npy`` files.
        kappa_dir (str): directory containing the ``kappa_*.npy`` files.
    """

    def __init__(self, gamma1_dir, gamma2_dir, kappa_dir):
        self.gamma1_dir = gamma1_dir
        self.gamma2_dir = gamma2_dir
        self.kappa_dir = kappa_dir

    def __getitem__(self, index):
        """Load sample ``index`` from disk.

        Args:
            index (int or 0-d torch.Tensor): sample index, also used verbatim
                (zero-padded to 3 digits) in the file names.

        Returns:
            dict: ``{"gamma": Tensor, "kappa": Tensor}`` as described on the class.
        """
        if torch.is_tensor(index):
            index = index.tolist()
        x1 = np.load(f"{self.gamma1_dir}/g1_{index:03}.npy")
        x2 = np.load(f"{self.gamma2_dir}/g2_{index:03}.npy")
        y = np.load(f"{self.kappa_dir}/kappa_{index:03}.npy")
        x1 = torch.from_numpy(x1).float()
        x2 = torch.from_numpy(x2).float()
        y = torch.from_numpy(y).float()

        x = torch.stack((x1, x2))
        y = torch.stack((y,))

        sample = {'gamma': x, 'kappa': y}

        return sample

    def __len__(self):
        """Return the number of ``*.npy`` files in ``gamma1_dir``.

        NOTE(review): assumes the files are numbered contiguously from 0, since
        ``__getitem__`` maps an index directly onto a file name — confirm.
        """
        # Match the suffix exactly: a substring test would also count names
        # such as "g1_000.npy.bak" that __getitem__ can never load.
        return sum(1 for name in os.listdir(self.gamma1_dir) if name.endswith('.npy'))

# Per-quantity input directories: noisy shear components g1/g2 and the
# convergence map kappa.
gamma1_dir = f"{dataset_dir}gamma1_w_noise"
gamma2_dir = f"{dataset_dir}gamma2_w_noise"
kappa_dir = f"{dataset_dir}kappa"

dataset = MyDataset(gamma1_dir, gamma2_dir, kappa_dir)

In [35]:
# Display the first sample to check its layout: "gamma" has 2 rows (g1, g2)
# and "kappa" has 1 row.
dataset[0]

{'gamma': tensor([[ 0.0041,  0.0022,  0.0012,  ..., -0.0071, -0.0034, -0.0002],
         [ 0.0207, -0.0167,  0.0117,  ...,  0.0044, -0.0048, -0.0019]]),
 'kappa': tensor([[ 0.0034,  0.0043, -0.0054,  ..., -0.0086,  0.0011, -0.0017]])}

In [49]:
def _get_means_stds(dataset, key: str):
    channel, length = torch.Tensor(dataset[0][key]).shape
    
    summing = torch.zeros(channel)
    square_summing = torch.zeros(channel)
    total = 0
    
    for idx in range(len(dataset)):
        sample = dataset[idx][key]
        summing += torch.sum(sample, dim=1)
        total += sample.shape[1]
    means = torch.unsqueeze(summing / total, dim=1)
    
    for idx in range(len(dataset)):
        sample = dataset[idx][key]
        square_summing += torch.sum((sample - means) ** 2, dim=1)
    stds = torch.sqrt(square_summing / (total - 1))
    
    return torch.squeeze(means, dim=1).numpy(), stds.numpy()

# Compute per-channel statistics for every sample entry ("gamma", "kappa")
# and cache them as ./tmp_save/{means,stds}_<key>.npy.
# np.save does not create missing parent directories, so create it first.
os.makedirs("./tmp_save", exist_ok=True)
for key in dataset[0].keys():
    means, stds = _get_means_stds(dataset, key)
    np.save(f"./tmp_save/means_{key}.npy", means)
    np.save(f"./tmp_save/stds_{key}.npy", stds)


In [59]:
def prepare_dataloader(gamma1_dir, gamma2_dir, kappa_dir, train_val_ratio, train_test_ratio, seed=None, batch_size=32):
    """Randomly split the gamma/kappa maps into train, validation and test loaders.

    Per-channel means/stds of every sample entry are computed on the TRAINING
    split only, so no information from validation/test maps leaks into the
    statistics. They are written to ``./tmp_save/means_<key>.npy`` and
    ``./tmp_save/stds_<key>.npy``.

    Args:
        gamma1_dir (str): directory of the ``g1_*.npy`` maps (see ``MyDataset``).
        gamma2_dir (str): directory of the ``g2_*.npy`` maps.
        kappa_dir (str): directory of the ``kappa_*.npy`` maps.
        train_val_ratio (float): fraction of samples used for TRAINING,
            despite the name. Validation receives whatever is left after the
            train and test parts.
        train_test_ratio (float): fraction of samples used for testing.
        seed (int, optional): seed for the random split. ``None`` (default)
            draws from torch's global RNG, as before.
        batch_size (int): batch size of all three loaders (default 32).

    Returns:
        (DataLoader, DataLoader, DataLoader): train, validation and test loaders.

    Raises:
        ValueError: if the ratios leave no training or no validation samples.
    """
    dataset = MyDataset(gamma1_dir, gamma2_dir, kappa_dir)

    num_samples = len(dataset)
    num_train_samples = int(round(num_samples * train_val_ratio))
    num_test_samples = int(round(num_samples * train_test_ratio))
    num_val_samples = num_samples - num_train_samples - num_test_samples
    if num_train_samples <= 0 or num_val_samples <= 0:
        raise ValueError(
            f"split ratios leave {num_train_samples} training and {num_val_samples} "
            f"validation samples out of {num_samples}; both must be positive"
        )

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    splits = (num_train_samples, num_val_samples, num_test_samples)
    train_subjects, val_subjects, test_subjects = random_split(dataset, splits, **split_kwargs)

    # Statistics from the training subset only (module-level _get_means_stds).
    os.makedirs("./tmp_save", exist_ok=True)
    for key in train_subjects[0].keys():
        means, stds = _get_means_stds(train_subjects, key)
        np.save(f"./tmp_save/means_{key}.npy", means)
        np.save(f"./tmp_save/stds_{key}.npy", stds)
    # NOTE(review): the statistics are saved but not applied to the samples.
    # MyDataset has no transform hook, and torchvision's Normalize expects
    # image tensors (..., C, H, W) rather than the (C, N_pixels) maps produced
    # here — a custom normalisation step is presumably still needed.

    dataloader_train = DataLoader(train_subjects, batch_size=batch_size, shuffle=True, num_workers=10)
    # Each evaluation loader iterates its own subset; order does not matter there.
    dataloader_val = DataLoader(val_subjects, batch_size=batch_size, shuffle=False, num_workers=6)
    dataloader_test = DataLoader(test_subjects, batch_size=batch_size, shuffle=False, num_workers=6)

    return dataloader_train, dataloader_val, dataloader_test

In [54]:
from deepsphere.utils.initialization import init_device

In [None]:
# NOTE(review): elsewhere in this notebook init_device is called as
# init_device(device, model) and returns (model, device). Called with no
# arguments here, it presumably raises TypeError — confirm, or drop this cell.
init_device()

In [60]:
# Input directories (trailing slash so file names can be appended directly).
gamma1_dir = dataset_dir+'gamma1_w_noise/'
gamma2_dir = dataset_dir+'gamma2_w_noise/'
kappa_dir = dataset_dir+'kappa/'

# NOTE(review): prepare_dataloader treats train_val_ratio as the TRAINING
# fraction and train_test_ratio as the test fraction; validation gets the
# remainder (here about 20% train / 70% val / 10% test) — check this is intended.
dataloader_train, dataloader_val, dataloader_test = prepare_dataloader(
    gamma1_dir, gamma2_dir, kappa_dir,
    train_val_ratio=0.2,
    train_test_ratio=0.1,
)

criterion = nn.MSELoss()
learning_rate = 1e-3
# Only the map length is needed: mmap_mode reads the .npy header instead of
# loading the whole map into memory.
length_per_map = np.load(kappa_dir + 'kappa_000.npy', mmap_mode='r').shape[0]

unet = SphericalUNet(
    pooling_class="healpix",
    N=length_per_map,
    depth=4,
    # laplacian_type is a required argument (main() passes it positionally
    # between depth and kernel_size). "combinatorial" is presumably the
    # intended type — confirm against the model configuration.
    laplacian_type="combinatorial",
    kernel_size=3,
)

# The original passed the undefined bare name `cpu`. main() passes a device
# identifier (parser_args.device) here; "cpu" is presumably what was meant.
unet, device = init_device("cpu", unet)
optimizer = optim.Adam(unet.parameters(), lr=learning_rate)

NameError: name 'nn' is not defined

In [22]:
def trainer(engine, batch):
    """Run one optimisation step of the U-Net on a single batch.

    Called by the ignite ``Engine`` once per training iteration; uses the
    notebook globals ``unet``, ``device``, ``criterion`` and ``optimizer``.

    Args:
        engine (ignite.engine.Engine): the training engine (unused).
        batch (dict): collated ``MyDataset`` samples with keys ``"gamma"``
            (model input) and ``"kappa"`` (regression target).

    Returns:
        float: the training loss of this batch.
    """
    unet.train()
    x = batch['gamma'].to(device)
    y = batch['kappa'].to(device)

    # NOTE(review): MyDataset yields (channels, pixels) per sample, i.e.
    # (B, C, V) batches, whereas main() feeds the model Permute()-d data and
    # treats its output as (B, V, C). A permute(0, 2, 1) of x and y may be
    # needed — confirm against SphericalUNet's expected input layout.
    y_hat = unet(x)

    # Prediction and target are compared directly with the regression loss.
    # The class-index flattening of the segmentation example does not apply
    # here; its `output`/`labels` names were never defined, so it raised
    # UnboundLocalError.
    loss = criterion(y_hat, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# NOTE(review): `parser_args` is not defined in any notebook cell (it exists
# only as main()'s parameter), so this line raises NameError as written.
# Define the TensorBoard path explicitly or build parser_args first.
writer = SummaryWriter(parser_args.tensorboard_path)

# Training engine: calls trainer() once per batch.
engine_train = Engine(trainer)

# NOTE(review): copied from the AR/TC segmentation example.
# - average_precision_compute_fn and validate_output_transform are defined in
#   a later cell.
# - Average precision is a classification metric; an MSE-style metric is
#   presumably wanted for kappa regression.
# - create_supervised_evaluator's default batch preparation expects (x, y)
#   pairs, while MyDataset yields dicts, so a custom prepare_batch is
#   presumably required. Confirm all three before trusting validation results.
engine_validate = create_supervised_evaluator(
    model=unet, metrics={"AP": EpochMetric(average_precision_compute_fn)}, device=device, output_transform=validate_output_transform
)

# Print the epoch number when each training epoch starts.
engine_train.add_event_handler(Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
# Stop training if an iteration's output (the loss) becomes NaN (ignite TerminateOnNan).
engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

@engine_train.on(Events.EPOCH_COMPLETED)
def epoch_validation(engine):
    """Run the validation engine at the end of each training epoch.

    Args:
        engine (ignite.engine): train engine
    """
    print("beginning validation epoch")
    # In this notebook the validation loader returned by prepare_dataloader is
    # `dataloader_val`; `dataloader_validation` only exists inside main().
    engine_validate.run(dataloader_val)

# Lower the learning rate when the validation score stops improving; it is
# stepped from update_reduce_on_plateau with the epoch's mean AP.
# NOTE(review): relies on `parser_args`, which is undefined in this notebook.
# Since a higher mean AP is better, the mode presumably needs to be "max" — confirm.
reduce_lr_plateau = ReduceLROnPlateau(
    optimizer,
    mode=parser_args.reducelronplateau_mode,
    factor=parser_args.reducelronplateau_factor,
    patience=parser_args.reducelronplateau_patience,
)

@engine_validate.on(Events.EPOCH_COMPLETED)
def update_reduce_on_plateau(engine):
    """Step the plateau LR scheduler with this validation epoch's mean AP.

    Args:
        engine (ignite.engine): validation engine whose ``state.metrics``
            holds the per-label "AP" vector
    """
    per_label_ap = engine.state.metrics["AP"]
    # Label index 0 is excluded from the mean, as in save_epoch_results.
    reduce_lr_plateau.step(np.mean(per_label_ap[1:]))

@engine_validate.on(Events.EPOCH_COMPLETED)
def save_epoch_results(engine):
    """Print and log the validation metrics at the end of each validation epoch.

    Args:
        engine (ignite.engine): validation engine
    """
    ap = engine.state.metrics["AP"]
    mean_average_precision = np.mean(ap[1:])
    print("Average precisions:", ap)
    print("mAP:", mean_average_precision)
    writer.add_scalars(
        "metrics",
        {"mean average precision (AR+TC)": mean_average_precision, "AR average precision": ap[2], "TC average precision": ap[1]},
        engine_train.state.epoch,
    )
    # Flush rather than close: the same writer receives scalars again on every
    # later epoch.
    writer.flush()

# Step-decay LR schedule with a warm-up phase, advanced once per validation epoch.
step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma)
scheduler = create_lr_scheduler_with_warmup(
    step_scheduler,
    warmup_start_value=parser_args.warmuplr_warmup_start_value,
    warmup_end_value=parser_args.warmuplr_warmup_end_value,
    warmup_duration=parser_args.warmuplr_warmup_duration,
)
engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

# ignite's EarlyStopping counts a HIGHER score as an improvement, so the score
# is the AP itself. Negating it would make a rising AP count as deterioration
# and stop training while the model is still improving.
earlystopper = EarlyStopping(
    patience=parser_args.earlystopping_patience, score_function=lambda x: x.state.metrics["AP"][1], trainer=engine_train
)
engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

# NOTE(review): add_tensorboard is defined in a later cell and `parser_args` is
# undefined in this notebook; both must exist before this cell runs.
add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path)

engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt")

(3145728,)

In [None]:
def average_precision_compute_fn(y_pred, y_true):
    """Compute per-label average precision; used as the ``compute_fn`` of an ignite ``EpochMetric``.

    Args:
        y_pred (:obj:`torch.Tensor`): model scores, one column per label
            (assumed (N, n_labels) — matches validate_output_transform's output).
        y_true (:obj:`torch.Tensor`): binary ground truths, same shape as ``y_pred``.

    Raises:
        RuntimeError: Indicates that sklearn should be installed by the user.

    Returns:
        :obj:`numpy.ndarray`: average precision vector (``average=None``),
        of the same length as the number of labels present in the data.
    """
    try:
        from sklearn.metrics import average_precision_score
    except ImportError as err:
        raise RuntimeError("This metric requires sklearn to be installed.") from err

    # `average` is keyword-only in current scikit-learn; passing it
    # positionally raises TypeError. detach()/cpu() make the NumPy conversion
    # safe for tensors that track gradients or live on a GPU.
    ap = average_precision_score(
        y_true.detach().cpu().numpy(),
        y_pred.detach().cpu().numpy(),
        average=None,
    )
    return ap


# Pylint and Ignite incompatibilities:
# pylint: disable=W0612
# pylint: disable=W0613


def validate_output_transform(x, y, y_pred):
    """Flatten evaluator outputs so per-label metrics can be computed on them.

    Used as the ``output_transform`` of ``create_supervised_evaluator``, which
    calls it as ``output_transform(x, y, y_pred)``.

    Args:
        x (:obj:`torch.Tensor`): the input to the model (unused)
        y (:obj:`torch.Tensor`): the ground truth labels, shape (B, V, C)
        y_pred (:obj:`torch.Tensor`): the output of the model, shape (B, V, C)

    Returns:
        (:obj:`torch.Tensor`, :obj:`torch.Tensor`): model predictions and
        ground truths, each reshaped to (B * V, C)
    """
    output = y_pred
    labels = y
    B, V, C = output.shape
    B_labels, V_labels, C_labels = labels.shape
    # reshape (unlike view) also works when the model output is not contiguous.
    output = output.reshape(B * V, C)
    labels = labels.reshape(B_labels * V_labels, C_labels)
    return output, labels


def add_tensorboard(engine_train, optimizer, model, log_dir):
    """Create an ignite TensorBoard logger and attach training handlers to it.

    Logged items:

    * the training loss, every iteration;
    * optimizer parameters (e.g. learning rate), weight histograms and
      gradient histograms, every epoch.

    The logger is closed when ``engine_train`` fires ``Events.COMPLETED``.

    Args:
        engine_train (:obj:`ignite.engine`): the train engine to attach to the logger
        optimizer (:obj:`torch.optim`): the model's optimizer
        model (:obj:`torch.nn.Module`): the model being trained
        log_dir (string): path to where tensorboard data should be saved

    Returns:
        TensorboardLogger: the attached logger (callers may ignore it).
    """
    tb_logger = TensorboardLogger(log_dir=log_dir)

    # Training loss at each iteration.
    tb_logger.attach(
        engine_train, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {"loss": loss}), event_name=Events.ITERATION_COMPLETED
    )

    # Optimizer parameters, e.g. learning rate, after each epoch.
    tb_logger.attach(engine_train, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_COMPLETED)

    # Model weights as histograms after each epoch.
    tb_logger.attach(engine_train, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)

    # Model gradients as histograms after each epoch.
    tb_logger.attach(engine_train, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED)

    # Close once training has finished. Closing here, right after attaching
    # and before any handler has run, ended the logger before anything was logged.
    engine_train.add_event_handler(Events.COMPLETED, lambda _: tb_logger.close())

    return tb_logger


def get_dataloaders(parser_args):
    """Create the AR/TC train and validation datasets and their dataloaders.

    Dataset indices are split with ``train_test_split``:

    * ``partition[0]`` is the train size;
    * the remainder is divided between validation and test in the ratio
      ``partition[1] : partition[2]`` (the test indices are not used here).

    Normalisation statistics come from the training split and are saved to
    ``./means.npy`` / ``./stds.npy``, unless both ``means_path`` and
    ``stds_path`` are given, in which case they are loaded from those files.

    Args:
        parser_args: parsed configuration object; the attributes read are
            path_to_data, download, partition, seed, means_path, stds_path
            and batch_size.

    Returns:
        (:obj:`torch.utils.data.DataLoader`, :obj:`torch.utils.data.DataLoader`): train, validation dataloaders

    Raises:
        ValueError: if ``means_path`` / ``stds_path`` are given but cannot be loaded.
    """
    path_to_data = parser_args.path_to_data
    download = parser_args.download
    partition = parser_args.partition
    seed = parser_args.seed
    means_path = parser_args.means_path
    stds_path = parser_args.stds_path

    # Untransformed dataset, used here only for its list of sample indices.
    data = ARTCDataset(path=path_to_data, download=download, indices=None, transform_data=None, transform_labels=None)

    train_indices, temp = train_test_split(data.indices, train_size=partition[0], random_state=seed)
    val_indices, _ = train_test_split(temp, test_size=partition[2] / (partition[1] + partition[2]), random_state=seed)

    if (means_path is None) or (stds_path is None):
        transform_data_stats = transforms.Compose([ToTensor()])
        train_set_stats = ARTCDataset(
            path=path_to_data, download=download, indices=train_indices, transform_data=transform_data_stats, transform_labels=None
        )
        means, stds = stats_extractor(train_set_stats)
        np.save("./means.npy", means)
        np.save("./stds.npy", stds)
    else:
        try:
            means = np.load(means_path)
            stds = np.load(stds_path)
        except (OSError, ValueError) as err:
            # Fail here: continuing without statistics only surfaced later as
            # a NameError on `means`/`stds`.
            raise ValueError(
                f"Could not load means/stds from {means_path!r} and {stds_path!r}."
            ) from err

    transform_data = transforms.Compose([ToTensor(), Permute(), Normalize(mean=means, std=stds)])
    transform_labels = transforms.Compose([ToTensor(), Permute()])
    train_set = ARTCDataset(
        path=path_to_data, download=download, indices=train_indices, transform_data=transform_data, transform_labels=transform_labels
    )
    validation_set = ARTCDataset(
        path=path_to_data, download=download, indices=val_indices, transform_data=transform_data, transform_labels=transform_labels
    )

    dataloader_train = DataLoader(train_set, batch_size=parser_args.batch_size, shuffle=True, num_workers=12)
    dataloader_validation = DataLoader(validation_set, batch_size=parser_args.batch_size, shuffle=False, num_workers=12)
    return dataloader_train, dataloader_validation


def main(parser_args):
    """Train the spherical U-Net on the AR/TC dataset, validating after every epoch.

    Steps:

    1. Build the dataloaders, the model and the optimizer.
    2. Build an ignite training engine and a supervised validation engine.
    3. Attach LR scheduling, early stopping and TensorBoard logging.
    4. Train for ``parser_args.n_epochs`` epochs.
    5. Save the model's ``state_dict`` to
       ``parser_args.model_save_path + "unet_state.pt"``.

    Args:
        parser_args: parsed configuration (as returned by ``parse_config``);
            its attributes are read directly, e.g. ``parser_args.learning_rate``.
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()

    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels, parser_args.depth, parser_args.laplacian_type, parser_args.kernel_size)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)

    def trainer(engine, batch):
        """Train Function to define train engine.
        Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            float: train loss for that batch
        """
        unet.train()
        data, labels = batch
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)

        # Flatten predictions to (B * V, C) rows. The argmax over the label
        # channel gives the class index that CrossEntropyLoss expects.
        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]

        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)

    engine_validate = create_supervised_evaluator(
        model=unet, metrics={"AP": EpochMetric(average_precision_compute_fn)}, device=device, output_transform=validate_output_transform
    )

    engine_train.add_event_handler(Events.EPOCH_STARTED, lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    # NOTE(review): this scheduler is stepped with a mean AP (higher is
    # better), so parser_args.reducelronplateau_mode presumably must be "max"
    # — confirm in the config.
    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {"mean average precision (AR+TC)": mean_average_precision, "AR average precision": ap[2], "TC average precision": ap[1]},
            engine_train.state.epoch,
        )
        # Flush (not close): the same writer is reused on every later epoch
        # and is closed once training ends.
        writer.flush()

    step_scheduler = StepLR(optimizer, step_size=parser_args.steplr_step_size, gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    # ignite's EarlyStopping counts a HIGHER score as an improvement, so the
    # score is the AP itself. Negating it made a rising AP count as
    # deterioration and stopped training while the model was improving.
    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience, score_function=lambda x: x.state.metrics["AP"][1], trainer=engine_train
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train, optimizer, unet, log_dir=parser_args.tensorboard_path)

    try:
        engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)
    finally:
        writer.close()

    torch.save(unet.state_dict(), parser_args.model_save_path + "unet_state.pt")


if __name__ == "__main__":
    # Build the command-line parser, resolve the configuration, then train.
    cli_parser = create_parser()
    PARSER_ARGS = parse_config(cli_parser)
    main(PARSER_ARGS)