In [1]:
from utils import augment
from utils.augment import *

from trainer import *

import jax.numpy as jnp
import jax
import flax.linen as nn
import numpy as np
import random
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from utils.dataloader import mel_dataset

# Start a Weights & Biases run for this notebook. `wandb` is not imported
# explicitly in this cell; presumably it is re-exported by
# `from trainer import *` — TODO confirm.
wandb.init(
project='CLR',
entity='aiffelthon'
)

# Full mel-spectrogram dataset; it is split 80/20 into train/test further below.
# The meaning of the 'total' argument is defined by utils.dataloader.mel_dataset.
data = mel_dataset('/mnt/disks/sdb/dev_dataset', 'total')

# Per-sample augmenters used by collate_batch, applied as crop(mix(x)).
# Presumably provided by `from utils.augment import *` — TODO confirm.
mix = MixupBYOLA()
crop = RandomResizeCrop()


def _augment_view(x):
    """Rescale one raw spectrogram, add a channel axis, and apply mixup then random resize-crop."""
    # Fixed affine rescale of the raw values; the constants presumably reflect
    # the dataset's value range — TODO confirm.
    x = (np.array(x) + 127) / 100
    x = np.expand_dims(x, axis=-1)
    return crop(mix(x))


def collate_batch(batch):
    """Collate ``(x, y)`` samples into two independently augmented views.

    Every sample is augmented twice with the module-level ``mix``/``crop``
    augmenters. The returned image array holds ``2 * len(batch)`` entries:
    all first views, followed by all second views, normalised with
    ``augment.post_norm``.

    Args:
        batch: Sequence of ``(x, y)`` pairs from the dataset.

    Returns:
        Tuple ``(views, labels)``. ``labels`` has one entry per sample (not per
        view), in the original batch order.
    """
    # Two full passes (all view-1 draws, then all view-2 draws) rather than one
    # interleaved pass. This keeps the order of random draws and of any state
    # the augmenters carry unchanged.
    first_views = [_augment_view(x) for x, _ in batch]
    second_views = [_augment_view(x) for x, _ in batch]

    y_train = [y for _, y in batch]
    views = np.stack(first_views + second_views, axis=0)
    return augment.post_norm(views), np.array(y_train)

def eval_collate_batch(batch):
    """Collate ``(x, y)`` samples for feature extraction, without augmentation.

    Applies the same ``(x + 127) / 100`` rescale as ``collate_batch``. It also
    appends a trailing channel axis, so eval batches have the same
    ``(batch, H, W, 1)`` layout as the training batches and the model's
    example input. Without that axis, a channels-last conv would treat the
    last spatial dimension as features.

    NOTE(review): unlike ``collate_batch`` this does not apply
    ``augment.post_norm``. Confirm whether the encoder expects post-normalised
    inputs at eval time.

    Args:
        batch: Sequence of ``(x, y)`` pairs from the dataset.

    Returns:
        Tuple ``(x_batch, labels)`` as NumPy arrays.
    """
    x_eval = [np.expand_dims((np.array(x) + 127) / 100, axis=-1) for x, _ in batch]
    y_eval = [y for _, y in batch]

    return np.array(x_eval), np.array(y_eval)


def train_simclr(num_epochs, *, train_loader=None, test_loader=None, **kwargs):
    """Build a BYOL-A trainer and train it for ``num_epochs`` epochs.

    Despite its name, this trains ``BYOL_ATrainer``.

    Args:
        num_epochs: Number of epochs passed to ``train_model``.
        train_loader: Training DataLoader. Defaults to the module-level
            ``train_dataloader``.
        test_loader: Evaluation DataLoader. Defaults to the module-level
            ``test_dataloader``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        The trained ``BYOL_ATrainer``.
    """
    # Example input used to initialise the model. The leading 128 appears to
    # match collate_batch output (two views per sample) at batch_size=64 —
    # TODO confirm whenever the loaders' batch size changes.
    trainer = BYOL_ATrainer(exmp=jnp.ones((128, 48, 1876, 1)))

    # The default loaders are module globals defined further down the file,
    # so they are looked up at call time.
    if train_loader is None:
        train_loader = train_dataloader
    if test_loader is None:
        test_loader = test_dataloader

    trainer.train_model(train_loader, test_loader, num_epochs=num_epochs)
    return trainer

def numpy_collate(batch):
    """Collate a batch into NumPy arrays instead of torch tensors.

    - If the elements are ndarrays, they are stacked along a new leading axis.
    - If the elements are tuples or lists, they are transposed field-wise and
      each field is collated recursively; the result is a list.
    - Anything else is converted with ``np.array``.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        return list(map(numpy_collate, zip(*batch)))
    return np.array(batch)
        
def prepare_data_features(encode_fn, data, batch_size=128):
    """Encode every sample of ``data`` and return the features sorted by label.

    Args:
        encode_fn: Callable applied to each input batch produced by
            ``eval_collate_batch``. Its output is moved to host memory with
            ``jax.device_get``.
        data: Map-style dataset yielding ``(x, y)`` pairs.
        batch_size: Number of samples passed to ``encode_fn`` per call.

    Returns:
        ``NumpyDataset(feats, labels)`` with rows reordered by label.
    """
    # Every sample must be encoded exactly once. Shuffling is pointless
    # because the result is sorted below, and drop_last=True would silently
    # discard up to batch_size - 1 samples. If encode_fn is jax.jit-compiled,
    # the smaller final batch costs one extra compilation.
    loader = DataLoader(data,
                        batch_size=batch_size,
                        shuffle=False,
                        drop_last=False,
                        num_workers=0,
                        collate_fn=eval_collate_batch)

    feats, labels = [], []
    for batch_imgs, batch_labels in tqdm(loader):
        batch_feats = encode_fn(batch_imgs)
        feats.append(jax.device_get(batch_feats))
        labels.append(batch_labels)

    feats = np.concatenate(feats, axis=0)
    labels = np.concatenate(labels, axis=0)

    # Group samples by label for easier post-processing. A stable sort keeps
    # dataset order within each label, so the output is reproducible.
    # Assumes labels are 1-D (one label per sample) — TODO confirm.
    idxs = labels.argsort(kind='stable')
    labels, feats = labels[idxs], feats[idxs]

    return NumpyDataset(feats, labels)
    
    
class NumpyDataset(mel_dataset):
    """Map-style dataset over parallel arrays (a NumPy analogue of TensorDataset).

    Each positional argument is an array indexed along its first axis. Item
    ``i`` is a list holding element ``i`` of every array, in argument order.
    ``mel_dataset.__init__`` is not called, so whatever setup it performs is
    skipped.
    """

    def __init__(self, *arrays):
        # Stored as the tuple received. The arrays are assumed to share their
        # first-axis length; this is not checked.
        self.arrays = arrays

    def __len__(self):
        # Length is taken from the first array only.
        leading = self.arrays[0]
        return leading.shape[0]

    def __getitem__(self, idx):
        return [column[idx] for column in self.arrays]

    
def train_logreg(batch_size, train_feats_data, test_feats_data, num_epochs=100, *,
                 return_results=False, **kwargs):
    """Train an ``LGTrainer`` on pre-extracted features and evaluate it.

    Args:
        batch_size: Batch size for both loaders.
        train_feats_data: Training dataset of ``(feature, label)`` items. It is
            reshuffled with a generator seeded at 42, and the incomplete last
            batch is dropped. As a result, a dataset smaller than
            ``batch_size`` yields no batches, and fetching the example batch
            raises ``StopIteration``.
        test_feats_data: Evaluation dataset. It is neither shuffled nor
            truncated.
        num_epochs: Number of epochs passed to ``train_model``.
        return_results: If True, also return the final evaluation results.
        **kwargs: Forwarded to ``LGTrainer``.

    Returns:
        The trained ``LGTrainer``. If ``return_results`` is True, returns
        ``(trainer, {'train': ..., 'test': ...})`` instead, holding the
        values returned by ``trainer.eval_model`` on each loader.
    """
    # Data loaders
    train_loader = DataLoader(train_feats_data,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              generator=torch.Generator().manual_seed(42),
                              collate_fn=numpy_collate)
    test_loader = DataLoader(test_feats_data,
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
                             collate_fn=numpy_collate)

    # The first training batch's features serve as the example input for
    # model initialisation.
    trainer = LGTrainer(exmp=next(iter(train_loader))[0],
                        # model_suffix=model_suffix,
                        **kwargs)
    trainer.train_model(train_loader, test_loader, num_epochs=num_epochs)

    # Evaluate the trainer's final state on both splits. These results were
    # previously computed and discarded; they are now optionally returned.
    train_result = trainer.eval_model(train_loader)
    test_result = trainer.eval_model(test_loader)

    if return_results:
        return trainer, {'train': train_result, 'test': test_result}
    return trainer
    
# 80/20 train/test split of the full dataset.
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size

# Seed the split so every run, including a later feature-extraction /
# linear-probe session, sees the same partition. Unseeded, random_split draws
# from torch's default global generator, so the partition changes per run.
train_dataset, test_dataset = random_split(
    data, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

# collate_batch emits two augmented views per sample, so each step feeds
# 2 * 64 = 128 inputs to the model.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, collate_fn=collate_batch)
# Evaluation needs no shuffling; a fixed order keeps batch composition
# reproducible (the augmentations themselves remain random).
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, collate_fn=collate_batch)
simclr_trainer = train_simclr(num_epochs=10)

# Linear-probe evaluation on frozen encoder features (currently disabled).
#     simclr_model = simclr_trainer.model.bind({'params': simclr_trainer.state.params,
#                                           'batch_stats': simclr_trainer.state.batch_stats},
#                                         mutable=['batch_stats'])

#     encode_fn = jax.jit(lambda img: simclr_model.encode(img))
#     train_feats_simclr = prepare_data_features(encode_fn, train_dataset)
#     test_feats_simclr = prepare_data_features(encode_fn, test_dataset)

#     trainer = train_logreg(batch_size=128,
#                          train_feats_data=train_feats_simclr,
#                          test_feats_data=test_feats_simclr,
#                          num_classes=30,
#                          lr=1e-3,
#                          weight_decay=1e-3)

  from .autonotebook import tqdm as notebook_tqdm
2022-09-20 05:11:20.456163: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-09-20 05:11:20.484030: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-09-20 05:11:21.121079: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2022-09-20 05:11:21.121217: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such fi

Load song_meta.json...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 707989/707989 [00:00<00:00, 758564.63it/s]


Load complete!

Load file list...


4it [00:00, 23.12it/s]
wandb: Waiting for W&B process to finish... (success).                                                                                                                                                           
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:  eval_loss █▂▁▁▁▁▁▁▁▁
wandb: train_loss █▂▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb:  eval_loss 0.00035
wandb: train_loss 0.00043
wandb: 
wandb: Synced giddy-flower-135: https://wandb.ai/aiffelthon/CLR/runs/2vdz19ay
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20220920_051123-2vdz19ay/logs
https://symbolize.stripped_domain/r/?trace=7f7d40b9ac7f,7f7d40afb08f&map= 
*** SIGTERM received by PID 363478 (TID 363478) on cpu 89 from PID 356028; stack trace: ***
PC: @     0x7f7d40b9ac7f  (unknown)  wait4
    @     0x7f7af16e7294        976  (unknown)
    @     0x7f7d40afb09