# Self-Supervision

- Train a deep network using the self-supervised learning framework.

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

In [None]:
# Load some packages
import hydra
from omegaconf import OmegaConf
import wandb
import pprint

# custom package
from run_train import check_device_env
from run_ssl_train import prepare_and_run_ssl_train

---

## Specify the dataset, model, and train setting

In [None]:
project = 'caueeg-ssl'
data_cfg_file = 'caueeg-dementia'
train_cfg_file = 'base_train'
ssl_cfg_file = 'byol'
model_cfg_file = '2D-Conformer-9-512'
device = 'cuda:0'

---

## Initializing configurations using Hydra

In [None]:
with hydra.initialize(config_path='../config'):
    add_configs = [f"data={data_cfg_file}", 
                   f"++data.seq_length=4000",
                   f"++data.input_norm=datapoint",
                   f"++data.awgn=0.05",
                   f"++data.mgn=0.05",
                   f"++data.dropout=0.3",
                   f"++data.channel_dropout=0.2",
                   f"++data.crop_multiple=2",
                   f"model={model_cfg_file}",
                   # f"++model.minibatch_3090=512",
                   f"++model.criterion=multi-bce",
                   f"train={train_cfg_file}",
                   f"+train.device={device}", 
                   f"+train.project={project}",
                   f"++train.lr_scheduler_type=cosine_decay_with_warmup_half",
                   f"++train.total_samples=1e+8",
                   f"++train.save_model=True",
                   f"ssl={ssl_cfg_file}",
                  ]
    
    cfg = hydra.compose(config_name='default', overrides=add_configs)
    
config = {**OmegaConf.to_container(cfg.data), 
          **OmegaConf.to_container(cfg.train),
          **OmegaConf.to_container(cfg.model),
          **OmegaConf.to_container(cfg.ssl)}

check_device_env(config)
pprint.pprint(config)

## Train

In [None]:
%%time
prepare_and_run_ssl_train(rank=None, world_size=None, config=config)

In [1]:
import torch

In [8]:
def get_sine_cosine_positional_embedding(seq_len, dim, class_token=False):
    if dim % 2 != 0:
        raise ValueError("get_sine_cosine_positional_embedding(dim): dim is not multiple of 2.")

    omega = torch.arange(dim // 2, dtype=torch.float)
    omega /= dim / 2.0
    omega = 1.0 / 10000**omega

    position = torch.arange(seq_len, dtype=torch.float)
    product = torch.einsum("l,d->ld", position, omega)

    embedding_sine = torch.sin(product)
    embedding_cosine = torch.cos(product)
    embedding = torch.concatenate([embedding_sine, embedding_cosine], dim=1)

    if class_token:
        embedding = torch.concatenate([torch.zeros((1, dim)), embedding], dim=0)

    return embedding

In [69]:
N, C, L = 4, 5, 6
mask_ratio = 0.5
L_keep = round(L * (1 - mask_ratio))

# random sampling and sorting for masking
random_noise = torch.rand((N, 1, L))
random_shuffle = torch.argsort(random_noise, dim=2)
idx_origin = torch.argsort(random_shuffle, dim=2)
idx_keep = random_shuffle[:, :, :L_keep]

# masking
x = torch.arange(4*5*6).reshape(4, 5, 6)
x_masked = torch.gather(x, dim=2, index=idx_keep.repeat(1, C, 1))

mask = torch.ones((N, C, L), device=x.device)
mask[:, :, :L_keep] = 0
mask = torch.gather(mask, dim=2, index=idx_origin.repeat(1, C, 1))

In [70]:
random_noise

tensor([[[0.2052, 0.0371, 0.7387, 0.3624, 0.7959, 0.5146]],

        [[0.6499, 0.6601, 0.0834, 0.4661, 0.7708, 0.2395]],

        [[0.5109, 0.4014, 0.0619, 0.0332, 0.2876, 0.0418]],

        [[0.9235, 0.4293, 0.3096, 0.3839, 0.9873, 0.7749]]])

In [71]:
random_shuffle

tensor([[[1, 0, 3, 5, 2, 4]],

        [[2, 5, 3, 0, 1, 4]],

        [[3, 5, 2, 4, 1, 0]],

        [[2, 3, 1, 5, 0, 4]]])

In [72]:
idx_origin

tensor([[[1, 0, 4, 2, 5, 3]],

        [[3, 4, 0, 2, 5, 1]],

        [[5, 4, 2, 0, 3, 1]],

        [[4, 2, 0, 1, 5, 3]]])

In [73]:
idx_keep

tensor([[[1, 0, 3]],

        [[2, 5, 3]],

        [[3, 5, 2]],

        [[2, 3, 1]]])

In [74]:
mask

tensor([[[0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 1., 1.],
         [0., 0., 1., 0., 1., 1.]],

        [[1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.]],

        [[1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.],
         [1., 1., 0., 0., 1., 0.]],

        [[1., 0., 0., 0., 1., 1.],
         [1., 0., 0., 0., 1., 1.],
         [1., 0., 0., 0., 1., 1.],
         [1., 0., 0., 0., 1., 1.],
         [1., 0., 0., 0., 1., 1.]]])