In [None]:
import os
from pathlib import Path
import time
import tqdm
import torch

In [None]:
# Use the 'spawn' start method for worker processes.
# force=True keeps this cell re-runnable: without it, a second call in the same
# kernel raises RuntimeError because the start method has already been set.
torch.multiprocessing.set_start_method('spawn', force=True)

In [None]:
# for colab
!git clone https://github.com/laralex/Sk-DL2021-FinalProject
repo_dir = Path().absolute()/'Sk-DL2021-FinalProject'
%pushd Sk-DL2021-FinalProject
!git pull
!git checkout new_gen
!pip install pytorch_lightning
import sys
sys.path.append('Sk-DL2021-FinalProject')

In [None]:
# for local
# import sys
# sys.path.append('..')
# repo_dir = Path().absolute().parent

In [None]:
!pwd

import torch
from data.split_step_generator import SplitStepGenerator
from auxiliary.files import find_dataset_subdir

GOOGLE_DRIVE = True

if GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount(f'/content/drive')
    root_dir = Path('/content/drive/MyDrive/Sk-DL2021-Datasets')
else:
    root_dir = repo_dir.parent / 'generated_datasets'
if not os.path.exists(root_dir):
    os.makedirs(root_dir, exist_ok=True)
    
root_dir

In [None]:
import yaml

# Name of the YAML file (without extension) inside the repository's `configs` directory.
CONFIG_NAME = 'generate_dataset_nonlin0.05-0.5_compensate'
CONFIG = repo_dir / 'configs' / (CONFIG_NAME + '.yaml')

# Only the `data.init_args` section of the config is used as generator arguments.
with CONFIG.open('r') as stream:
    config_hparams = yaml.safe_load(stream)['data']['init_args']
# Optional manual override, disabled by default:
# config_hparams['dispersion_compensate'] = True
data_gen = SplitStepGenerator(**config_hparams)

In [None]:
import yaml
import datetime 

# Directory name handed to create_destination as `new_dir_name`; reuses the config name.
NEW_DIR_NAME = CONFIG_NAME

def create_destination(hparams, datasets_root, new_dir_name=None):
    """Create a dataset directory under `datasets_root` and store `hparams` in it.

    Args:
        hparams: Object serialised with `yaml.dump` into `signal_hparams.yaml`
            inside the new directory.
        datasets_root: Directory in which the dataset directory is created.
        new_dir_name: Name of the directory to create. When None, the current
            local time formatted as `%m-%d-%Y=%H-%M-%S` is used instead.

    Returns:
        Path of the dataset directory.

    Raises:
        AssertionError: If the directory already holds a `signal_hparams.yaml`.
    """
    if new_dir_name is None:
        new_dir_name = datetime.datetime.now().strftime("%m-%d-%Y=%H-%M-%S")
    # Build the path from the argument rather than from the notebook-level
    # `root_dir`, so the function honours the root the caller asked for.
    new_dir = Path(datasets_root) / new_dir_name
    os.makedirs(new_dir, exist_ok=True)
    hparams_path = new_dir / 'signal_hparams.yaml'
    # Never overwrite the hyper-parameters of an already described dataset.
    assert not hparams_path.exists(), f'{hparams_path} already exists'
    with open(hparams_path, 'w') as outfile:
        yaml.dump(hparams, outfile, default_flow_style=False)
    return new_dir
    
# Reuse the directory that find_dataset_subdir returns for these signal
# hyper-parameters; create a fresh one only when it returns None.
destination_root = find_dataset_subdir(data_gen.signal_hparams, root_dir)
destination_root = (
    create_destination(data_gen.signal_hparams, root_dir, NEW_DIR_NAME)
    if destination_root is None
    else destination_root
)
print('Destination: ', destination_root)

In [None]:
# make folders structure
def save_tensor(tensor, subdir):
    """Save `tensor` as the first free `<index>.pt` file inside `subdir`.

    Files are numbered from 0; the lowest index without an existing file is used.

    Args:
        tensor: Tensor to store. None only prints a notice, and a tensor with
            zero elements is skipped silently; neither creates a file.
        subdir: Existing directory that receives the file (str or path-like).
    """
    if tensor is None:
        print('Nothing to save', subdir)
        return
    if tensor.numel() == 0:
        return
    subdir = Path(subdir)
    i = 0
    while True:
        destination_path = subdir / f"{i}.pt"
        try:
            # Mode 'x' fails when the file already exists, so the slot is
            # claimed and written in one step: no placeholder file and no gap
            # between the existence check and the write.
            with open(destination_path, 'xb') as outfile:
                # A copy is serialised rather than the tensor itself - presumably
                # so a view is stored without the rest of its storage; TODO confirm.
                torch.save(tensor.clone(), outfile)
            return
        except FileExistsError:
            i += 1
    
# One directory per split, in the fixed order train, val, test
# (later cells index this list as type_subdirs[0] and type_subdirs[1]).
type_subdirs = [destination_root / split for split in ('train', 'val', 'test')]
for d in type_subdirs:
    os.makedirs(d, exist_ok=True)

TRAINING

In [None]:
import yaml

# Size of the training run: only training batches are generated here.
BATCH_SIZE = 20
GENERATE_TRAIN_BATCHES = 50
GENERATE_VAL_BATCHES = 0
GENERATE_TEST_BATCHES = 0

# NOTE(review): these two limits are not referenced by any cell visible here,
# so the generator keeps whatever the config file specifies - confirm intent.
MIN_NONLIN = 0.02
MAX_NONLIN = 0.4

if not os.path.exists(CONFIG):
    print('Config file cant be found')
else:
    with open(CONFIG, 'r') as stream:
        train_hparams = yaml.safe_load(stream)['data']['init_args']
    # Override the batch bookkeeping from the config with the values above.
    train_hparams.update(
        batch_size=BATCH_SIZE,
        generate_n_train_batches=GENERATE_TRAIN_BATCHES,
        generate_n_val_batches=GENERATE_VAL_BATCHES,
        generate_n_test_batches=GENERATE_TEST_BATCHES,
    )
    data_gen = SplitStepGenerator(**train_hparams)
    data_gen.prepare_data()

In [None]:
# Display the arguments the training generator was built with.
train_hparams

In [None]:
# generate and save
loader = data_gen.train_dataloader()
loader.num_workers = 0
for inp, target in tqdm.tqdm(loader):
    # A 4-dimensional batch is reduced by squeezing axis 0; after that both
    # tensors must be 3-dimensional.
    if target.dim() == 4:
        target, inp = target.squeeze(0), inp.squeeze(0)
        assert target.dim() == 3 and inp.dim() == 3
    # Stored layout: index 0 holds the target, index 1 holds the input.
    b = torch.stack((target, inp), dim=0)
    print(b.shape, b.sum())
    save_tensor(b, type_subdirs[0])

VALIDATION

In [None]:
import yaml

# Size of the validation run: only validation batches are generated here.
BATCH_SIZE = 100
GENERATE_TRAIN_BATCHES = 0
GENERATE_VAL_BATCHES = 3
GENERATE_TEST_BATCHES = 0

# Single nonlinearity value used for validation instead of a range.
NONLINEARITY = 0.4

if not os.path.exists(CONFIG):
    print('Config file cant be found')
else:
    with open(CONFIG, 'r') as stream:
        val_hparams = yaml.safe_load(stream)['data']['init_args']
    val_hparams.update({
        'batch_size': BATCH_SIZE,
        'generate_n_train_batches': GENERATE_TRAIN_BATCHES,
        'generate_n_val_batches': GENERATE_VAL_BATCHES,
        'generate_n_test_batches': GENERATE_TEST_BATCHES,
        # The limits are cleared and the fixed value is set in their place.
        'generation_nonlinearity_limits': None,
        'nonlinearity': NONLINEARITY,
    })

In [None]:
# Display the arguments prepared for the validation generator.
val_hparams

In [None]:
# Rebind `data_gen` to a generator built from the validation arguments and
# call prepare_data() on it before its dataloader is requested.
data_gen = SplitStepGenerator(**val_hparams)
data_gen.prepare_data()

In [None]:
# generate and save
loader = data_gen.val_dataloader()
loader.num_workers = 0
for inp, target in tqdm.tqdm(loader):
    # A 4-dimensional batch is reduced by squeezing axis 0; after that both
    # tensors must be 3-dimensional.
    if target.ndim == 4:
        target, inp = target.squeeze(0), inp.squeeze(0)
        assert target.ndim == 3 and inp.ndim == 3
    # Stored layout: index 0 holds the target, index 1 holds the input.
    b = torch.stack((target, inp))
    print(b.shape, b.sum())
    save_tensor(b, type_subdirs[1])

LOAD

In [None]:
# Switch the original config arguments to reading from disk; the same dict is
# updated in place and displayed at the end of the cell.
config_hparams.update(
    data_source_type='filesystem',
    load_dataset_root_path=root_dir,
    batch_size=20,
)
data_gen_load = SplitStepGenerator(**config_hparams)
config_hparams

In [None]:
# Run the generator's preparation step before a dataloader is requested from it.
data_gen_load.prepare_data()

In [None]:
# Rebind `loader` to the training dataloader of the filesystem-backed generator.
loader = data_gen_load.train_dataloader()
# NOTE(review): presumably a torch DataLoader, where num_workers = 0 means
# batches are loaded in the main process - confirm against SplitStepGenerator.
loader.num_workers = 0 

In [None]:
import matplotlib.pyplot as plt
for idx, t in enumerate(loader):
    fig = plt.figure()
    plt.plot(t[0][2, :300, 5].real)
    plt.plot(t[1][2, :300, 5].real)
    plt.plot(t[0][2, :300, 5].imag)