In [None]:
from utils import augment
from utils.augment import *

import jax.numpy as jnp
import jax
import flax.linen as nn
import numpy as np
import random
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch
from utils.dataloader import mel_dataset

In [None]:
data = mel_dataset('/mnt/disks/sdb/dataset', 'total')

In [None]:
mix = MixupBYOLA()
crop = RandomResizeCrop()

In [None]:
def collate_batch(batch):
    x_train_1 = []
    x_train_2 = []
    
    for x, y in batch:
        x = (np.array(x)+127)/100
        x = np.expand_dims(x, axis=-1)
        x = crop(mix(x))        
        x_train_1.append(x)
        
    for x, y in batch:
        x = (np.array(x)+127)/100
        x = np.expand_dims(x, axis=-1)
        x = crop(mix(x))        
        x_train_2.append(x)
            
    y_train = [y for _, y in batch]           
    return augment.post_norm(np.stack(x_train_1 + x_train_2, axis=0)), np.array(y_train)

In [None]:
dataset_size = len(data)
train_size = int(dataset_size * 0.8)    
test_size = dataset_size - train_size

train_dataset, test_dataset, = random_split(data, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=0, collate_fn=collate_batch)

In [None]:
@partial(jax.jit, static_argnames=['k'])
def top_k(loss_list, ts, k):
    top_k = jnp.argsort(loss_list, axis=1)[:k]
    correct = 0
    for i in range(ts.shape[0]):
        b = (jnp.where(top_k[i,:] == ts[i], jnp.ones((top_k[i,:].shape)), 0)).sum()
        correct += b
    correct /= ts.shape[0]
    return correct

In [1]:
def train_simclr(num_epochs, **kwargs):
    # Create a trainer module with specified hyperparameters
    trainer = AudioEncoderTrainer(exmp=jnp.ones((16,48,1876,1)))
    trainer.train_model(train_dataloader, test_dataloader, num_epochs=1)

    return trainer

In [None]:
simclr_trainer = train_simclr(10)

In [3]:
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

In [4]:
class NumpyDataset(mel_dataset):
    # data.TensorDataset for numpy arrays

    def __init__(self, *arrays):
        self.arrays = arrays

    def __len__(self):
        return self.arrays[0].shape[0]

    def __getitem__(self, idx):
        return [arr[idx] for arr in self.arrays]

NameError: name 'mel_dataset' is not defined

In [5]:
def prepare_data_features(encode_fn, data):
    # Encode all images
    dataset = DataLoader(data, batch_size=16,
                                  shuffle=True,
                                  drop_last=True,
                                  num_workers=0,
                                  collate_fn=collate_batch)
    
    feats, labels = [], []
    for batch_imgs, batch_labels in tqdm(dataset):
        batch_feats = encode_fn(batch_imgs)
        feats.append(jax.device_get(batch_feats))
        labels.append(batch_labels)

    feats = np.concatenate(feats, axis=0)
    labels = np.concatenate(labels, axis=0)

    # Sort images by labels for easier postprocessing later
    idxs = labels.argsort()
    labels, feats = labels[idxs], feats[idxs]

    return NumpyDataset(feats, labels)

In [None]:
train_feats_simclr = prepare_data_features(encode_fn, train_dataset)
test_feats_simclr = prepare_data_features(encode_fn, test_dataset)

In [6]:
simclr_model = simclr_trainer.model.bind({'params': simclr_trainer.state.params,
                                          'batch_stats': simclr_trainer.state.batch_stats},
                                        mutable=['batch_stats'])

NameError: name 'simclr_trainer' is not defined

In [None]:
encode_fn = jax.jit(lambda img: simclr_model.encode(img))

In [7]:
def train_logreg(batch_size, train_feats_data, test_feats_data, num_epochs=100, **kwargs):
    # Data loaders
    train_loader = DataLoader(train_feats_data,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=True,
                                   generator=torch.Generator().manual_seed(42),
                                   collate_fn=numpy_collate)
    test_loader = DataLoader(test_feats_data,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  drop_last=False,
                                  collate_fn=numpy_collate)

    # Create a trainer module with specified hyperparameters
    trainer = LGTrainer(exmp=next(iter(train_loader))[0],
                        # model_suffix=model_suffix,
                        **kwargs)
    trainer.train_model(train_loader, test_loader, num_epochs=num_epochs)

    # Test best model on train and validation set
    train_result = trainer.eval_model(train_loader) 
    test_result = trainer.eval_model(test_loader)
    result = {"train": train_result["acc"], "test": test_result["acc"]}

    return trainer, result

In [8]:
results = {}

_, result = train_logreg(batch_size=16,
                         train_feats_data=train_feats_simclr,
                         test_feats_data=test_feats_simclr,
                         num_classes=30,
                         lr=1e-3,
                         weight_decay=1e-3
                         )

NameError: name 'train_feats_simclr' is not defined