In [1]:
from argparse import Namespace
import numpy as np
import torch

from maml.datasets.esc50 import ESC50MetaDataset
from main import parse_args

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The dataset built further down receives this flag as its `train` argument.
is_training = False

# Command-line style options handed to the project's argument parser.
cli_options = [
    '--dataset', 'esc50',
    '--num-batches', '6000',
    '--output-folder', 'mmaml_5mode_5w1s',
    '--verbose',
    '--model-type', 'gated_conv_1d',
    '--embedding-type', 'ConvGRU1d',
    '--num-workers', '0',
    '--eval',
    '--checkpoint', 'train_dir/mmaml_5mode_5w1s/maml_gated_conv_1d_6000.pt',
]
args = parse_args(cli_options)

# Cap the number of sampled embeddings at the number of batches.
args.num_sample_embedding = min(args.num_sample_embedding, args.num_batches)

# Compute the embedding dims when they were left at 0: one entry per gated
# conv layer, doubling with depth, and doubled once more for the 'affine'
# condition type.
num_gated_conv_layers = 4
if args.embedding_dims == 0:
    width_factor = 2 if args.condition_type == 'affine' else 1
    args.embedding_dims = [
        args.num_channels * 2**layer_idx * width_factor
        for layer_idx in range(num_gated_conv_layers)
    ]

# The two model flags are mutually exclusive.
assert (not args.mmaml_model) or (not args.maml_model)

In [3]:
# Keyword arguments for the ESC-50 meta-dataset; everything except the data
# root and the two audio_* settings is taken from the parsed arguments.
dataset_kwargs = dict(
    root='data',
    audio_side_len=32,
    audio_channel=1,
    num_classes_per_batch=args.num_classes_per_batch,
    num_samples_per_class=args.num_samples_per_class,
    num_total_batches=args.num_batches,
    num_val_samples=args.num_val_samples,
    meta_batch_size=args.meta_batch_size,
    train=is_training,
    num_train_classes=args.num_train_classes,
    num_workers=args.num_workers,
    device=args.device,
)
dataset = ESC50MetaDataset(**dataset_kwargs)

# Indexed with an integer class id when naming the exported wav files.
classmap = dataset.classname

# Loss used for both the inner-loop updates and the validation loss.
loss_func = torch.nn.CrossEntropyLoss()

# Forwarded to MetaLearner as `collect_accuracies`.
collect_accuracies = True

In [4]:
from maml.models.conv_embedding_1d_model import ConvEmbeddingOneDimensionalModel
from maml.models.gated_conv_net_1d import GatedConv1dModel

# Classifier network; its input/output sizes are taken from the dataset.
model = GatedConv1dModel(
    input_channels=dataset.input_size[0],
    output_size=dataset.output_size,
    num_channels=args.num_channels,
    img_side_len=dataset.input_size[1],
    use_max_pool=args.use_max_pool,
    verbose=args.verbose)

# Embedding network; its per-task output is passed to the classifier as the
# `embeddings` argument.
embedding_model = ConvEmbeddingOneDimensionalModel(
    input_size=np.prod(dataset.input_size),
    output_size=dataset.output_size,
    embedding_dims=args.embedding_dims,
    hidden_size=args.embedding_hidden_size,
    num_layers=args.embedding_num_layers,
    convolutional=args.conv_embedding,
    num_conv=args.num_conv_embedding_layer,
    num_channels=args.num_channels,
    rnn_aggregation=(not args.no_rnn_aggregation),
    embedding_pooling=args.embedding_pooling,
    batch_norm=args.conv_embedding_batch_norm,
    avgpool_after_conv=args.conv_embedding_avgpool_after_conv,
    linear_before_rnn=args.linear_before_rnn,
    num_sample_embedding=args.num_sample_embedding,
    sample_embedding_file=args.sample_embedding_file+'.'+args.sample_embedding_file_type,
    img_size=dataset.input_size,
    verbose=args.verbose)
embedding_parameters = list(embedding_model.parameters())


if args.checkpoint != '':
    # map_location remaps the stored tensors onto the requested device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only
    # host (and vice versa) instead of failing during deserialization.
    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(args.device)
    if embedding_model:
        embedding_model.load_state_dict(checkpoint['embedding_model_state_dict'])


In [5]:
from maml.metalearner import MetaLearner

# The optimizer list is left empty (presumably because this notebook only
# evaluates a trained checkpoint -- confirm against MetaLearner).
optimizers = []

# Hyper-parameters forwarded unchanged from the parsed arguments.
meta_learner_options = dict(
    fast_lr=args.fast_lr,
    loss_func=loss_func,
    first_order=args.first_order,
    num_updates=args.num_updates,
    inner_loop_grad_clip=args.inner_loop_grad_clip,
    collect_accuracies=collect_accuracies,
    device=args.device,
    alternating=args.alternating,
    embedding_schedule=args.embedding_schedule,
    classifier_schedule=args.classifier_schedule,
    embedding_grad_clip=args.embedding_grad_clip,
)
meta_learner = MetaLearner(model, embedding_model, optimizers,
                           **meta_learner_options)


In [6]:

def adapt(train_tasks):
    """Run the inner-loop adaptation for every task in `train_tasks`.

    For each task the parameters start from `model.param_dict` and are
    updated `args.num_updates` times with `update_params` on the loss of the
    task's own samples. When `embedding_model` is truthy, it is evaluated
    once per task and its output is passed to the model on every forward pass.

    Returns:
        A pair `(adapted_params, embeddings_list)` with one entry per task:
        the adapted parameters and the task embeddings (`None` when no
        embedding model is used).
    """
    adapted_params, embeddings_list = [], []
    for task in train_tasks:
        fast_weights = model.param_dict
        task_embeddings = embedding_model(task) if embedding_model else None
        for _ in range(args.num_updates):
            logits = model(task, params=fast_weights, embeddings=task_embeddings)
            inner_loss = loss_func(logits, task.y)
            fast_weights = update_params(inner_loss, params=fast_weights)
        adapted_params.append(fast_weights)
        embeddings_list.append(task_embeddings)
    return adapted_params, embeddings_list


def step(adapted_params_list, embeddings_list, val_tasks):
    """Evaluate the adapted parameters on the validation tasks.

    The three arguments are walked in lockstep: the model is run on each
    validation task with that task's adapted parameters and embeddings, and
    `loss_func` is applied to the predictions and `task.y`.

    Returns:
        A pair `(mean_loss, pred_list)`: the mean of the per-task losses and
        the list of per-task model outputs.
    """
    pred_list, post_update_losses = [], []
    triples = zip(adapted_params_list, embeddings_list, val_tasks)
    for fast_weights, task_embeddings, task in triples:
        task_preds = model(task, params=fast_weights, embeddings=task_embeddings)
        pred_list.append(task_preds)
        post_update_losses.append(loss_func(task_preds, task.y))

    return torch.stack(post_update_losses).mean(), pred_list

def update_params(loss, params):
    """Apply one step of gradient descent on `loss` and return the result.

    The step size is `args.fast_lr`. When `args.inner_loop_grad_clip` is
    positive, every gradient is clamped element-wise to
    `[-inner_loop_grad_clip, inner_loop_grad_clip]` before the update.
    Unless `args.first_order` is set, the gradient graph is kept
    (`create_graph=True`).

    Args:
        loss: scalar tensor to differentiate.
        params: mapping from parameter name to tensor. It is NOT modified;
            the update is written into a shallow copy so the caller's
            mapping keeps pointing at the pre-update tensors.

    Returns:
        A new mapping of the same kind as `params` holding the updated
        tensors. Parameters that received no gradient are carried over
        unchanged.
    """
    create_graph = not args.first_order
    grads = torch.autograd.grad(loss, tuple(params.values()),
                                create_graph=create_graph, allow_unused=True)

    # Work on a copy: overwriting entries of the incoming mapping would leak
    # the adapted weights back to whoever owns it.
    updated = params.copy()
    for (name, param), grad in zip(params.items(), grads):
        if grad is None:
            # allow_unused=True yields None for parameters not used by loss.
            continue
        if args.inner_loop_grad_clip > 0:
            grad = grad.clamp(min=-args.inner_loop_grad_clip,
                              max=args.inner_loop_grad_clip)
        updated[name] = param - args.fast_lr * grad

    return updated


In [7]:
import soundfile as sf

def tensor2wavfile(tensor, filename, sample_rate=44100):
    """Write a one-dimensional tensor to `filename` via `sf.write`.

    Args:
        tensor: tensor with at most one dimension holding the samples.
        filename: destination path handed to `sf.write`.
        sample_rate: sample rate handed to `sf.write`.

    Raises:
        ValueError: if `tensor` has more than one dimension.
    """
    if tensor.ndim > 1:
        raise ValueError(f'tensor with dim {tensor.ndim}')
    # Detach first: .numpy() refuses tensors that require grad. Then make
    # sure the data lives on the CPU before converting.
    samples = tensor.detach().cpu().numpy()
    sf.write(filename, samples, sample_rate)


In [8]:
# Collect (train_tasks, val_tasks, predictions) for the first meta-batch only.
# The notebook-level adapt()/step() are used here rather than the
# meta_learner.adapt()/meta_learner.step() methods.
results = []
for train_tasks, val_tasks in dataset:
    adapted_params, embeddings = adapt(train_tasks)
    mean_loss, pred_list = step(adapted_params, embeddings, val_tasks)
    results.append((train_tasks, val_tasks, pred_list))
    # Stop after a single batch.
    break

input size: torch.Size([5, 1, 80000])
conv1: torch.Size([5, 32, 16000])
bn1: torch.Size([5, 32, 16000])
relu1: torch.Size([5, 32, 16000])
conv2: torch.Size([5, 64, 3200])
bn2: torch.Size([5, 64, 3200])
relu2: torch.Size([5, 64, 3200])
conv3: torch.Size([5, 128, 640])
bn3: torch.Size([5, 128, 640])
relu3: torch.Size([5, 128, 640])
conv4: torch.Size([5, 256, 128])
bn4: torch.Size([5, 256, 128])
relu4: torch.Size([5, 256, 128])
reshape to: torch.Size([5, 256, 128])
reduce mean: torch.Size([5, 256])
fc: torch.Size([1, 128, 5])
reshape after avgpool: torch.Size([1, 128])
emb vec 1 size: torch.Size([1, 64])
emb vec 2 size: torch.Size([1, 128])
emb vec 3 size: torch.Size([1, 256])
emb vec 4 size: torch.Size([1, 512])
input size: torch.Size([5, 1, 80000])
layer1_conv: torch.Size([5, 32, 16000])
layer1_bn: torch.Size([5, 32, 16000])
layer1_condition: torch.Size([5, 32, 16000])
layer1_relu: torch.Size([5, 32, 16000])
layer2_conv: torch.Size([5, 64, 3200])
layer2_bn: torch.Size([5, 64, 3200])
lay

In [9]:
# Print the tensor shapes of the first task: train inputs/labels,
# validation inputs/labels, then the model predictions.
first_task_shapes = (
    train_tasks[0].x.shape,
    train_tasks[0].y.shape,
    val_tasks[0].x.shape,
    val_tasks[0].y.shape,
    pred_list[0].shape,
)
for shape in first_task_shapes:
    print(shape)


torch.Size([5, 1, 80000])
torch.Size([5])
torch.Size([75, 1, 80000])
torch.Size([75])
torch.Size([75, 5])


In [14]:
import os
import os.path as osp
# Export every sample of every evaluated task as a wav file:
#   <output_dir>/task_<i>/train/...  samples the model was adapted on
#   <output_dir>/task_<i>/val/...    held-out samples, with the predicted label
output_dir = 'test'
sample_rate = 16000  # passed to tensor2wavfile for every exported file

# With several batches, give each one its own sub-folder so that task_<i>
# folders from different batches do not overwrite each other. With a single
# batch the layout stays <output_dir>/task_<i>/...
use_batch_dirs = len(results) > 1

for batch_idx, result in enumerate(results):
    batch_dir = osp.join(output_dir, f'batch_{batch_idx}') if use_batch_dirs else output_dir
    # result is (train_tasks, val_tasks, pred_list); zip(*result) walks them
    # task by task.
    for task_idx, (train_task, val_task, preds) in enumerate(zip(*result)):
        task_dir = osp.join(batch_dir, f'task_{task_idx}')
        train_dir = osp.join(task_dir, 'train')
        val_dir = osp.join(task_dir, 'val')
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # Predicted label per validation sample: argmax over the last axis.
        pred_labels = preds.argmax(dim=-1)

        # NOTE(review): `gt` is used as an index into classmap; presumably it
        # is the dataset-wide class id, while `y` is the per-task label --
        # confirm against the dataset.
        # audio_tensor[0] selects the first entry of each sample so that
        # tensor2wavfile receives a one-dimensional tensor.
        for sample_idx, (audio_tensor, _y, gt) in enumerate(
                zip(train_task.x, train_task.y, train_task.gt)):
            tensor2wavfile(
                audio_tensor[0],
                osp.join(train_dir, f'train_audio_gt{gt}_id{sample_idx}_category_{classmap[int(gt)]}.wav'),
                sample_rate=sample_rate)

        for sample_idx, (audio_tensor, _y, pred, gt) in enumerate(
                zip(val_task.x, val_task.y, pred_labels, val_task.gt)):
            tensor2wavfile(
                audio_tensor[0],
                osp.join(val_dir, f'val_audio_gt{gt}_pred{pred}_id{sample_idx}_category_{classmap[int(gt)]}.wav'),
                sample_rate=sample_rate)