In [5]:
# !pip install --upgrade pip
# !pip install musdb
# !pip install julius
# !pip install lameenc
# !pip install einops
# !pip install omegaconf
# !pip install diffq
# !pip install openunmix
# !pip install pytorch_lightning 
# !pip install torch --upgrade
# !pip install torchaudio 
# !pip install torchaudio --upgrade

In [6]:
from wav import get_datasets
import augment
import distrib
from demucs_p import demucs_phase

In [7]:
import torch
import torchaudio
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from torchmetrics.functional.audio import signal_distortion_ratio

In [33]:
# Build the training split and the held-out split. The held-out ('test')
# split is used further down as the validation set of the LightningModule.
train_set, test_set = (
    get_datasets(data_type=split_name, metadata=metadata_dir)
    for split_name, metadata_dir in (
        ('train', './metadata'),
        ('test', './metadata1'),
    )
)

In [16]:
class Lit_phase(pl.LightningModule):
    """Lightning wrapper that trains a ``demucs_phase`` source-separation model.

    Each batch is used directly as the stack of target sources (presumably
    shaped (batch, sources, channels, time) -- TODO confirm against
    ``get_datasets``). The model input is the sum of the sources over
    ``dim=1``, and the loss is the L1 distance between the model output and
    those sources. Training batches are augmented first (shift, channel flip,
    sign flip), so the mixture is formed from the augmented sources.

    Args:
        train_set: dataset wrapped by ``train_dataloader`` (shuffled, drop_last).
        valid_set: dataset wrapped by ``val_dataloader`` (not shuffled).
        batch_size: batch size for both dataloaders.
        num_workers: DataLoader worker count for both dataloaders.
        lr: Adam learning rate used by ``configure_optimizers``.
        sources: source names forwarded to ``demucs_phase``.
        samplerate: multiplied by ``shift_seconds`` to get the integer
            ``shift`` passed to ``augment.Shift`` (presumably samples per
            second -- confirm against the dataset).
        shift_seconds: see ``samplerate``.
    """

    def __init__(self,
                 train_set,
                 valid_set,
                 batch_size=8,
                 num_workers=1,
                 lr=3e-4,
                 sources=('bass', 'acap', 'other', 'drums'),
                 samplerate=44100,
                 shift_seconds=1):
        super().__init__()
        # A tuple default avoids a shared mutable default; the model receives
        # a fresh list, exactly as before.
        self.model = demucs_phase(sources=list(sources), use_train_segment=False)
        self.criterion = nn.L1Loss()

        # Training-time augmentation pipeline, applied to the stacked sources
        # before they are summed into the mixture.
        self.augment = nn.Sequential(
            augment.Shift(shift=int(samplerate * shift_seconds), same=True),
            augment.FlipChannels(),
            augment.FlipSign(),
        )

        self._train_set = train_set
        self._val_set = valid_set
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

    def forward(self, x):
        """Run the wrapped separation model on a mixture ``x``."""
        return self.model(x)

    def _shared_step(self, sources):
        """Sum ``sources`` over dim 1 into a mixture, separate it, return the L1 loss."""
        mix = sources.sum(dim=1)
        estimates = self(mix)
        return self.criterion(estimates, sources)

    def training_step(self, batch, batch_idx):
        """Augment the batch, compute the separation loss and log ``train_loss``."""
        loss = self._shared_step(self.augment(batch))
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the separation loss on an un-augmented batch.

        The metric is logged under the key ``"valid_loss"``; callbacks that
        monitor it (e.g. early stopping, checkpointing) must use this exact key.
        """
        loss = self._shared_step(batch)
        self.log("valid_loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def train_dataloader(self):
        """Shuffled training loader; ``drop_last`` keeps every batch full-size."""
        return torch.utils.data.DataLoader(
            self._train_set,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Deterministic (unshuffled) validation loader."""
        return torch.utils.data.DataLoader(
            self._val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [17]:
# Stop training once the validation loss stops improving by more than
# `min_delta` for `patience` consecutive validation checks.
# The monitored key must match the name logged in Lit_phase.validation_step
# ("valid_loss"). The previous value "val_loss" is never logged, so with
# Lightning's default strict=True the callback would raise at validation end.
early_stopping = EarlyStopping(
    monitor="valid_loss",
    mode="min",  # lower loss is better (also Lightning's default)
    min_delta=0.0001,
    patience=5,
)

In [18]:
# Instantiate the Lightning module; the held-out 'test' split doubles as the
# validation set that drives the "valid_loss" metric.
model_phase = Lit_phase(
    train_set=train_set,
    valid_set=test_set,
    batch_size=16,
    num_workers=15,
)

In [20]:
# trainer = pl.Trainer(accelerator='gpu',
#                      devices='auto',
#                     max_epochs = 1000,
#                     callbacks=[early_stopping]#,
#                     #strategy = 'ddp_notebook'
#                     )

In [15]:
# trainer.fit(model=model_phase)