In [5]:
import torchaudio as ta
from torch.utils.data import Dataset

import os
import torch.nn.functional as F
import torch.nn as nn
import torch

from pl_model import PL_model
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Data

In [6]:
class my_dataset(Dataset):
    """Paired audio dataset: one mic recording and one target recording per file name.

    Expects the directory layout ``<path>/mic/<name>`` and ``<path>/target/<name>``:
    every file listed in ``<path>/mic`` must have an identically named counterpart
    in ``<path>/target``.

    Args:
        path: Root directory of one data split (e.g. ``./beam_data/train``).
    """

    def __init__(self, path):
        self.path_target = os.path.join(path, 'target')
        self.path_mic = os.path.join(path, 'mic')
        # os.listdir order is arbitrary and filesystem-dependent; sort it so that a
        # given index always refers to the same file across runs and machines.
        # NOTE(review): every entry is treated as an audio file (hidden files such as
        # .DS_Store would also be picked up) — confirm the directory is clean.
        self.path_signal = sorted(os.listdir(self.path_mic))

    def __len__(self):
        """Number of mic files found in ``<path>/mic``."""
        return len(self.path_signal)

    @staticmethod
    def _match_length(target, length):
        """Zero-pad (or, if longer, trim) ``target`` on the right of its last axis to ``length``.

        ``F.pad`` with a negative amount removes elements, so a target longer than
        ``length`` is cropped rather than raising.
        """
        return F.pad(target, (0, length - target.shape[-1]))

    def __getitem__(self, idx):
        """Load the ``idx``-th pair and return ``(target, mic_array)``.

        ``target`` is right-padded/trimmed along the last (time) axis to the length
        of ``mic_array``; each tensor is laid out as torchaudio.load returns it
        (channels, frames).

        Raises:
            ValueError: if the target and mic files have different sample rates,
                since aligning them by sample count would then be meaningless.
        """
        name = self.path_signal[idx]

        target, target_sr = ta.load(os.path.join(self.path_target, name))
        mic_array, mic_sr = ta.load(os.path.join(self.path_mic, name))

        if target_sr != mic_sr:
            raise ValueError(
                f'sample-rate mismatch for {name}: target={target_sr} Hz, mic={mic_sr} Hz'
            )

        target = self._match_length(target, mic_array.shape[-1])
        return target, mic_array

# Training

In [7]:
# Seed Python, NumPy and torch RNGs *before* building the model so that weight
# initialisation and DataLoader shuffling are reproducible across kernel restarts.
SEED = 42
pl.seed_everything(SEED, workers=True)

# NOTE(review): the meaning of the positional argument 4 is defined in
# pl_model.PL_model, which is not visible here (presumably a channel / microphone
# count) — confirm and consider naming it as a constant.
pl_m = PL_model(4)

In [8]:
# Log the learning rate at every optimizer step so any LR schedule can be inspected.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Keep only the single best checkpoint: the one with the lowest "valid_loss".
# NOTE(review): assumes PL_model logs a metric named exactly "valid_loss" — confirm.
checkpoint_callback = ModelCheckpoint(monitor="valid_loss", mode='min', save_top_k=1)

# CPU-only run. Gradients are accumulated over 32 batches before each optimizer
# step, so the effective batch size is 32 x the DataLoader batch_size.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback, lr_monitor],
    accumulate_grad_batches=32,
    max_epochs=1000,
    devices="auto",
    accelerator="cpu",
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Both splits share the same layout: ./beam_data/<split>/{mic,target}/
data_train, data_valid = (my_dataset(f'./beam_data/{split}') for split in ('train', 'valid'))

In [11]:
# One loader per split; only the training loader shuffles, so validation batches
# arrive in the same order every epoch.
# batch_size=1 presumably because clips can differ in length and the default
# collate function cannot stack tensors of different sizes.
# NOTE(review): with spawn-based worker processes (macOS/Windows default), a
# Dataset class defined inside the notebook may fail to pickle into the worker;
# if so, move my_dataset to a .py module or use num_workers=0.
train_dl, valid_dl = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=1,
    )
    for dataset, shuffle in ((data_train, True), (data_valid, False))
)

In [13]:
# Training is opt-in: a full run can be very long (see max_epochs on the Trainer),
# so it is disabled by default to keep "Restart & Run All" cheap.
# Set RUN_TRAINING = True to actually fit the model.
RUN_TRAINING = False
if RUN_TRAINING:
    trainer.fit(pl_m, train_dl, valid_dl)