In [1]:
from ray.tune import Trainable
from ray.tune.utils import wait_for_gpu
import torch
import sys

import json

# print(f"sys path in experiment: {sys.path}")
from pathlib import Path

from shrp.datasets.dataset_simclr import SimCLRDataset

# import model_definitions
from shrp.models.def_AE_module import AEModule

from torch.utils.data import DataLoader

from ray.air.integrations.wandb import setup_wandb

import logging

from shrp.datasets.augmentations import (
    AugmentationPipeline,
    TwoViewSplit,
    WindowCutter,
    ErasingAugmentation,
    NoiseAugmentation,
    MultiWindowCutter,
    StackBatches,
    PermutationSelector,
)


In [2]:
experiment_root = Path('/netscratch2/kschuerholt/code/shrp/experiments/02_representation_learning/01_test/tune/ae_resnet_ffcv_permutation_test_1')
config_path = Path('/netscratch2/kschuerholt/code/shrp/experiments/02_representation_learning/01_test/tune/ae_resnet_ffcv_permutation_test_1/AE_trainable_b0ae4_00000_0_ae_d_model=512,ae_nhead=16,ae_num_layers=16,training_windowsize=1024_2023-05-12_16-54-42/params.json')
config = json.load(config_path.open('r'))

In [3]:
from ffcv.loader import Loader, OrderOption

# import new downstream module
from shrp.models.downstream_module_ffcv import DownstreamTaskLearner

# trainloader
batch_size = config["trainset::batchsize"]
num_workers = config.get("testloader::workers", 4)
ordering = OrderOption.QUASI_RANDOM
# Dataset ordering
path_trainset = str(config["dataset::dump"]) + ".train"
trainloader = Loader(
    path_trainset,
    batch_size=batch_size,
    num_workers=num_workers,
    order=ordering,
    drop_last=True,
    # pipelines=PIPELINES
    os_cache=False,
)
# trainloader
batch_size = config["trainset::batchsize"]
num_workers = config.get("testloader::workers", 4)
ordering = OrderOption.SEQUENTIAL
# Dataset ordering
path_testset = str(config["dataset::dump"]) + ".test"
testloader = Loader(
    path_testset,
    batch_size=batch_size,
    num_workers=num_workers,
    order=ordering,
    drop_last=True,
    # pipelines=PIPELINES
    os_cache=False,
)
# config
batch_size = config["trainset::batchsize"]
num_workers = config.get("testloader::workers", 4)
ordering = OrderOption.SEQUENTIAL
# Dataset ordering
path_valset = str(config["dataset::dump"]) + ".val"
valloader = Loader(
    path_valset,
    batch_size=batch_size,
    num_workers=num_workers,
    order=ordering,
    drop_last=True,
    # pipelines=PIPELINES
    os_cache=False,
)

In [4]:
device = 'cuda' if torch.cuda.is_available() else "cpu"

In [5]:
# test dataloaders
for idx, batch in enumerate(trainloader):
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break
        
# test dataloaders
for idx, batch in enumerate(valloader):
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break

# test dataloaders
for idx, batch in enumerate(testloader):
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break

0 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
1 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
2 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
3 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
1 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
2 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
3 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 10

In [16]:
# test dataloaders
for idx, batch in enumerate(valloader):
    batch = (ddx.to(device) for ddx in batch)
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break

# test dataloaders
for idx, batch in enumerate(testloader):
    batch = (ddx.to(device) for ddx in batch)
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break

# test dataloaders
for idx, batch in enumerate(trainloader):
    batch = (ddx.to(device) for ddx in batch)
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    if idx>2:
        break


0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
1 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
2 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
3 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
1 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
2 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
3 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024,

In [7]:

stack_1 = []
windowsize = config.get("training::windowsize", 15)
if config.get("trainset::add_noise_view_1", 0.0) > 0.0:
    stack_1.append(
        NoiseAugmentation(config.get("trainset::add_noise_view_1", 0.0))
    )
if config.get("trainset::erase_augment", None) is not None:
    stack_1.append(ErasingAugmentation(**config["trainset::erase_augment"]))
stack_2 = []
if config.get("trainset::add_noise_view_2", 0.0) > 0.0:
    stack_2.append(
        NoiseAugmentation(config.get("trainset::add_noise_view_2", 0.0))
    )
if config.get("trainset::erase_augment", None) is not None:
    stack_2.append(ErasingAugmentation(**config["trainset::erase_augment"]))

stack_train = []
if config.get("trainset::multi_windows", None):
    stack_train.append(StackBatches())
else:
    stack_train.append(WindowCutter(windowsize=windowsize))
# put train stack together
if config.get("training::permutation_number", 0) == 0:
    split_mode = "copy"
    view_1_canon = True
    view_2_canon = True
else:
    split_mode = "permutation"
    view_1_canon = config.get("training::view_1_canon", True)
    view_2_canon = config.get("training::view_2_canon", False)
stack_train.append(
    TwoViewSplit(
        stack_1=stack_1,
        stack_2=stack_2,
        mode=split_mode,
        view_1_canon=view_1_canon,
        view_2_canon=view_2_canon,
    ),
)

trafo_train = AugmentationPipeline(stack=stack_train)

In [6]:
# trafo_train
# TwoViewSplit
split_mode

'permutation'

In [7]:
ddx = torch.randn([64, 11, 1024, 288])
mdx = torch.randn([64, 1024, 288])
pdx = torch.randn([64, 1024, 3])

print(ddx.shape, mdx.shape, pdx.shape)
ddx, mdx, pdx, ddx2, mdx2, pdx2 = trafo_train(ddx, mdx, pdx)
print(ddx.shape, mdx.shape, pdx.shape)
print(ddx2.shape, mdx2.shape, pdx2.shape)


torch.Size([64, 11, 1024, 288]) torch.Size([64, 1024, 288]) torch.Size([64, 1024, 3])
torch.Size([64, 1024, 288]) torch.Size([64, 1024, 288]) torch.Size([64, 1024, 3])
torch.Size([64, 1024, 288]) torch.Size([64, 1024, 288]) torch.Size([64, 1024, 3])


In [23]:
# test dataloaders
for idx, batch in enumerate(trainloader):
    print(type(batch))
#     batch = (ddx.to(device) for ddx in batch)
    batch = [sdx.to(device) for sdx in batch]
    print(type(batch))
    print(f'{idx} - {[bdx.shape for bdx in batch]}')
    batch2 = trafo_train(*batch)
    print(f'{idx} - {[bdx.shape for bdx in batch2]}')
#     if idx>2:
    break

<class 'tuple'>
<class 'list'>
0 - [torch.Size([64, 11, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3])]


In [9]:
ddx = torch.randn([64, 11, 1024, 288])
perm_ids = torch.randperm(n=ddx.shape[-3], dtype=torch.int32, device=ddx.device)[
            :3
        ]
ddx2 = torch.index_select(ddx.clone(), -3, perm_ids[1]).squeeze()
print(ddx.shape)
print(ddx2.shape)

torch.Size([64, 11, 1024, 288])
torch.Size([64, 1024, 288])


In [8]:
# test AUGMENTATIONS
stack_1 = []
if config.get("testset::add_noise_view_1", 0.0) > 0.0:
    stack_1.append(
        NoiseAugmentation(config.get("testset::add_noise_view_1", 0.0))
    )
if config.get("testset::erase_augment", None) is not None:
    stack_1.append(ErasingAugmentation(**config["testset::erase_augment"]))
stack_2 = []
if config.get("testset::add_noise_view_2", 0.0) > 0.0:
    stack_2.append(
        NoiseAugmentation(config.get("testset::add_noise_view_2", 0.0))
    )
if config.get("testset::erase_augment", None) is not None:
    stack_2.append(ErasingAugmentation(**config["testset::erase_augment"]))

stack_test = []
if config.get("trainset::multi_windows", None):
    stack_test.append(StackBatches())
else:
    stack_test.append(WindowCutter(windowsize=windowsize))
# put together
if config.get("testing::permutation_number", 0) == 0:
    split_mode = "copy"
    view_1_canon = True
    view_2_canon = True
else:
    split_mode = "permutation"
    view_1_canon = config.get("testing::view_1_canon", True)
    view_2_canon = config.get("testing::view_2_canon", False)
stack_test.append(
    TwoViewSplit(
        stack_1=stack_1,
        stack_2=stack_2,
        mode=split_mode,
        view_1_canon=view_1_canon,
        view_2_canon=view_2_canon,
    ),
)

# TODO: pass through permutation / view_1/2 canonical
trafo_test = AugmentationPipeline(stack=stack_test)

In [22]:
# test dataloaders
import copy
for idx, batch in enumerate(valloader):
    print(type(batch))
    print(f'{idx} - {[bdx.shape for bdx in batch]}')

    batch  = [sdx.to(device) for sdx in batch]
    print(type(batch))
    
    batch_print = copy.deepcopy(batch)
    print(f'{idx} - {[bdx.shape for bdx in batch_print]}')
    batch2 = trafo_test(*batch)
#     batch2 = trafo_train(*batch)
    print(f'{idx} - {[bdx.shape for bdx in batch2]}')
    if idx>2:
        break

<class 'tuple'>
0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
<class 'list'>
0 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
0 - [torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3])]
<class 'tuple'>
1 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
<class 'list'>
1 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
1 - [torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3])]
<class 'tuple'>
2 - [torch.Size([64, 6, 1024, 288]), torch.Size([64, 1024, 288]), torch.Size([64, 1024, 3]), torch.Size([64, 3])]
<c

In [11]:
if config.get("training::permutation_number", 0) > 0:
    trafo_dst = PermutationSelector(mode="canonical")
else:
    trafo_dst = PermutationSelector(mode="identity")
