In [None]:
import numpy as np
import torch

from maml.datasets.esc50 import ESC50MetaDataset
# from maml.datasets.cifar100_simulate_1d import Cifar100MetaDataset
from maml.datasets.cifar100 import Cifar100MetaDataset
from main import parse_args

In [None]:
# Evaluation-only notebook: this flag is forwarded as `train=` to the datasets below.
is_training = False

# Command-line style options handed to the project's argument parser.
cli_arguments = [
    '--dataset', 'multimodal_few_shot',
    '--multimodal_few_shot', 'esc50', 'cifar',
    '--common-img-side-len', '32',
    '--common-img-channel', '3',
    '--num-batches', '6000',
    '--output-folder', 'mmaml_5mode_5w1s',
    '--verbose',
    '--model-type', 'gated_conv_1d',
    '--embedding-type', 'ConvGRU1d',
    '--num-workers', '0',
    '--eval',
    '--device', 'cuda:1',
    # '--checkpoint', 'train_dir/real_multimodal_2mode_5w1s/maml_gated_conv_1d_6000.pt',
]
args = parse_args(cli_arguments)

# Cap the number of sampled embeddings at the number of batches.
args.num_sample_embedding = min(args.num_sample_embedding, args.num_batches)

# Compute embedding dims: when none were given (value 0), derive one per gated
# conv layer as num_channels * 2**layer, doubled for 'affine' conditioning.
num_gated_conv_layers = 4
if args.embedding_dims == 0:
    width_multiplier = 2 if args.condition_type == 'affine' else 1
    args.embedding_dims = [
        args.num_channels * 2**layer * width_multiplier
        for layer in range(num_gated_conv_layers)
    ]

# The two model flags are mutually exclusive.
assert not args.mmaml_model or not args.maml_model

In [None]:
from maml.datasets.multimodal_few_shot import MultimodalFewShotDataset


# Settings that both modality-specific datasets receive unchanged.
shared_dataset_kwargs = dict(
    root='data',
    num_classes_per_batch=args.num_classes_per_batch,
    num_samples_per_class=args.num_samples_per_class,
    num_total_batches=args.num_batches,
    num_val_samples=args.num_val_samples,
    meta_batch_size=args.meta_batch_size,
    train=is_training,
    num_train_classes=args.num_train_classes,
    num_workers=args.num_workers,
    device=args.device,
)

# One meta-dataset per modality: ESC50 first, then CIFAR-100.
dataset_list = [
    ESC50MetaDataset(
        audio_side_len=32,
        audio_channel=1,
        **shared_dataset_kwargs,
    ),
    Cifar100MetaDataset(
        img_side_len=args.common_img_side_len,
        img_channel=args.common_img_channel,
        **shared_dataset_kwargs,
    ),
]

dataset_names = ' '.join(dataset.name for dataset in dataset_list)
print('Multimodal Few Shot Datasets: {}'.format(dataset_names))

# A text file for sample embeddings is only passed when embeddings are sampled.
if args.num_sample_embedding > 0:
    embedding_txt_file = args.sample_embedding_file + '.txt'
else:
    embedding_txt_file = None

# Wrap the per-modality datasets into a single multimodal dataset.
dataset = MultimodalFewShotDataset(
    dataset_list,
    num_total_batches=args.num_batches,
    mix_meta_batch=args.mix_meta_batch,
    mix_mini_batch=args.mix_mini_batch,
    txt_file=embedding_txt_file,
    train=is_training,
)

loss_func = torch.nn.CrossEntropyLoss()
collect_accuracies = True

# MULTIMODAL DATASET

In [None]:
from maml.models.conv_embedding_1d_model import ConvEmbeddingOneDimensionalModel
from maml.models.gated_conv_net_1d import GatedConv1dModel

# Classifier network; its input/output sizes are read from the multimodal dataset.
model = GatedConv1dModel(
    input_channels=dataset.input_size[0],
    output_size=dataset.output_size,
    num_channels=args.num_channels,
    img_side_len=dataset.input_size[1],
    use_max_pool=args.use_max_pool,
    verbose=args.verbose)

# Embedding network; called on a task in `adapt` to produce the embeddings
# that are passed to `model`.
embedding_model = ConvEmbeddingOneDimensionalModel(
        input_size=np.prod(dataset.input_size),
        output_size=dataset.output_size,
        embedding_dims=args.embedding_dims,
        hidden_size=args.embedding_hidden_size,
        num_layers=args.embedding_num_layers,
        convolutional=args.conv_embedding,
        num_conv=args.num_conv_embedding_layer,
        num_channels=args.num_channels,
        rnn_aggregation=(not args.no_rnn_aggregation),
        embedding_pooling=args.embedding_pooling,
        batch_norm=args.conv_embedding_batch_norm,
        avgpool_after_conv=args.conv_embedding_avgpool_after_conv,
        linear_before_rnn=args.linear_before_rnn,
        num_sample_embedding=args.num_sample_embedding,
        sample_embedding_file=args.sample_embedding_file+'.'+args.sample_embedding_file_type,
        img_size=dataset.input_size,
        verbose=args.verbose)
embedding_parameters = list(embedding_model.parameters())


# Restore trained weights when a checkpoint path was supplied.
if args.checkpoint != '':
    # Deserialize onto the CPU so the load does not depend on the device the
    # checkpoint was saved from; load_state_dict copies the values into the
    # existing parameters afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(args.device)
    if embedding_model:
        embedding_model.load_state_dict(checkpoint['embedding_model_state_dict'])
        # Keep the embedding network on the same device as the classifier.
        embedding_model.to(args.device)


# META LEARNER

In [None]:
from maml.metalearner import MetaLearner

# No optimizers are handed to the meta-learner in this notebook.
optimizers = []

# Keyword options for the meta-learner, all taken from the parsed arguments
# except the loss function and the accuracy-collection flag defined earlier.
meta_learner_options = dict(
    fast_lr=args.fast_lr,
    loss_func=loss_func,
    first_order=args.first_order,
    num_updates=args.num_updates,
    inner_loop_grad_clip=args.inner_loop_grad_clip,
    collect_accuracies=collect_accuracies,
    device=args.device,
    alternating=args.alternating,
    embedding_schedule=args.embedding_schedule,
    classifier_schedule=args.classifier_schedule,
    embedding_grad_clip=args.embedding_grad_clip,
)

meta_learner = MetaLearner(
    model, embedding_model, optimizers, **meta_learner_options)


In [None]:

def adapt(train_tasks):
    """Adapt the model parameters to each task in `train_tasks`.

    For every task, adaptation starts from `model.param_dict`, computes the
    task embeddings with `embedding_model` (or uses None when it is falsy),
    and applies `args.num_updates` rounds of `update_params` on the loss of
    the model's predictions against `task.y`.

    Args:
        train_tasks: iterable of task objects exposing a `y` attribute.

    Returns:
        A pair `(adapted_params, embeddings_list)` of lists with one entry
        per task, in the order the tasks were given.
    """
    per_task_params = []
    per_task_embeddings = []
    for task in train_tasks:
        fast_weights = model.param_dict
        task_embeddings = embedding_model(task) if embedding_model else None
        for _ in range(args.num_updates):
            task_preds = model(task, params=fast_weights, embeddings=task_embeddings)
            task_loss = loss_func(task_preds, task.y)
            fast_weights = update_params(task_loss, params=fast_weights)
        per_task_params.append(fast_weights)
        per_task_embeddings.append(task_embeddings)

    return per_task_params, per_task_embeddings


def step(adapted_params_list, embeddings_list, val_tasks):
    """Evaluate already-adapted parameters on the validation tasks.

    The three arguments are zipped together, so entry k of each belongs to
    the same task. For every task the model is run with its adapted
    parameters and embeddings, and the loss against `task.y` is recorded.

    Returns:
        A pair `(mean_loss, pred_list)`: the mean of the per-task losses and
        the list of per-task predictions.
    """
    pred_list = []
    task_losses = []
    triples = zip(adapted_params_list, embeddings_list, val_tasks)
    for task_params, task_embeddings, task in triples:
        task_preds = model(task, params=task_params, embeddings=task_embeddings)
        pred_list.append(task_preds)
        task_losses.append(loss_func(task_preds, task.y))

    return torch.stack(task_losses).mean(), pred_list

def update_params(loss, params):
    """Apply one step of gradient descent on `loss` and return the result.

    The step size is `args.fast_lr`. Gradients are clamped element-wise to
    `[-args.inner_loop_grad_clip, args.inner_loop_grad_clip]` when that value
    is positive. Parameters that do not contribute to `loss` (gradient is
    None) are carried over unchanged.

    Args:
        loss: scalar tensor to differentiate.
        params: dict mapping parameter names to tensors.

    Returns:
        A new dict of the same type as `params` holding the updated tensors.
        The dict passed in is left untouched, so the caller's starting
        parameters are never overwritten by an adaptation step.
    """
    # Keep the graph of the gradient computation unless first-order mode is on.
    create_graph = not args.first_order
    grads = torch.autograd.grad(loss, tuple(params.values()),
                                create_graph=create_graph, allow_unused=True)
    clip = args.inner_loop_grad_clip

    # Shallow copy: same tensors, separate container to write results into.
    updated_params = params.copy()
    for (name, param), grad in zip(params.items(), grads):
        if grad is None:
            continue
        if clip > 0:
            grad = grad.clamp(min=-clip, max=clip)
        updated_params[name] = param - args.fast_lr * grad

    return updated_params


# SAVE FILES

In [None]:
import soundfile as sf
from torchvision.transforms.functional import to_pil_image
import PIL.Image
from matplotlib import pyplot as plt
def tensor2wavfile(tensor, filename, sample_rate=44100):
    """Write a 1-D or 2-D tensor to `<filename>.wav` using soundfile.

    Args:
        tensor: tensor with at most two dimensions. A 2-D tensor is
            transposed before writing (multi channel audio); presumably it
            arrives as (channels, frames) -- confirm against the dataset.
        filename: output path without the '.wav' extension.
        sample_rate: sample rate passed to `sf.write`.

    Raises:
        ValueError: if the tensor has more than two dimensions.
    """
    if tensor.ndim > 2:
        raise ValueError(f'tensor with dim {tensor.ndim}')
    if tensor.ndim == 2:
        tensor = tensor.T # multi channel audio
    # detach() first: .numpy() raises on tensors that still require grad.
    samples = tensor.detach().cpu().numpy()
    sf.write(filename + '.wav', samples, sample_rate)

def tensor2jpgfile(tensor, filename, output_shape=(32, 32)):
    """Convert a tensor to a PIL image, resize it, and save `<filename>.jpg`.

    Args:
        tensor: tensor with at most three dimensions, handed to
            `to_pil_image` as-is.
        filename: output path without the '.jpg' extension.
        output_shape: size passed to the PIL image's `resize`.

    Raises:
        ValueError: if the tensor has more than three dimensions.
    """
    n_dims = tensor.ndim
    if n_dims > 3:
        raise ValueError(f'tensor with dim {n_dims}')
    resized = to_pil_image(tensor).resize(output_shape)
    resized.save(filename + '.jpg')
    
def task_tensor2file(taskname, tensor, filename):
    """Save one sample to disk in the format that matches its task name.

    * 'CIFAR1001d': the last axis is reshaped to (3, h, w), with h the
      integer square root of a third of its length, then saved as a JPEG.
    * 'FC100': saved as a JPEG without reshaping.
    * 'ESC50': saved as a WAV file at a 16000 Hz sample rate.

    Args:
        taskname: name identifying which dataset the sample came from.
        tensor: the sample to save.
        filename: output path without extension.

    Raises:
        ValueError: if `taskname` is none of the names listed above.
    """
    if taskname in ('CIFAR1001d',):
        pixels_per_channel = tensor.shape[-1] // 3
        height = int(np.sqrt(pixels_per_channel))
        width = pixels_per_channel // height
        tensor2jpgfile(tensor.view(3, height, width), filename)
        return
    if taskname in ('FC100',):
        tensor2jpgfile(tensor, filename)
        return
    if taskname in ('ESC50',):
        tensor2wavfile(tensor, filename, sample_rate=16000)
        return
    raise ValueError(f'not valid task name {taskname}')

In [None]:
# Gather (train_tasks, val_tasks, predictions) for the first three meta-batches.
results = []
for i, (train_tasks, val_tasks) in enumerate(dataset, start=1):
    # The notebook-local `adapt` / `step` are used here in place of
    # `meta_learner.adapt` / `meta_learner.step`, so that the raw predictions
    # are returned.
    adapted_params, embeddings = adapt(train_tasks)
    mean_loss, pred_list = step(adapted_params, embeddings, val_tasks)
    results.append((train_tasks, val_tasks, pred_list))

    # Counting starts at 1, so this stops after the third meta-batch.
    if i >= 3:
        break

In [None]:
# Shapes of the first task of the last collected meta-batch: train inputs,
# train labels, validation inputs, validation labels, then predictions.
for tensor_shape in (
    train_tasks[0].x.shape,
    train_tasks[0].y.shape,
    val_tasks[0].x.shape,
    val_tasks[0].y.shape,
    pred_list[0].shape,
):
    print(tensor_shape)

# Type and value of the task identifier stored on the very first train task.
first_task_info = results[0][0][0].task_info
print(type(first_task_info))
print(first_task_info)

# SAVE PREDICTIONS

In [None]:
import os
import os.path as osp

# Dump every collected sample to disk, one directory per meta-batch and task:
#   <output_dir>/result<j>/task_<i>/{train,val}/<descriptive file name>
output_dir='test'
for j, result in enumerate(results):
    result_output_dir = osp.join(output_dir, f'result{j}')
    train_tasks, val_tasks, pred_list = result
    for i, (train_task, val_task, preds) in enumerate(zip(train_tasks, val_tasks, pred_list)):
        # one task
        print(train_task.task_info)
        task_dir = osp.join(result_output_dir, f'task_{i}')
        train_dir = osp.join(task_dir, 'train')
        val_dir = osp.join(task_dir, 'val')
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)
        # Turn the per-class scores into predicted label indices.
        preds = preds.argmax(dim=-1)

        # Maps a task-local label (`y`) to the human-readable class name,
        # filled from the train samples and reused for the predictions.
        gt2class = {}
        # The sample loops use their own index so the task index `i` is not
        # overwritten. `tensor` and `train_task` keep their names because
        # later cells read them.
        for sample_idx, (tensor, y, gt) in enumerate(zip(train_task.x, train_task.y, train_task.gt)):
            class_name = dataset.get_class_name(train_task.task_info, int(gt))
            task_tensor2file(train_task.task_info, tensor, osp.join(train_dir, f'train_{train_task.task_info}_gt{y}_category_{class_name}_id{sample_idx}'))
            gt2class[y.tolist()] = class_name
        for sample_idx, (tensor, y, pred) in enumerate(zip(val_task.x, val_task.y, preds)):
            # The category in the file name is that of the *predicted* label.
            class_name = gt2class[pred.tolist()]
            task_tensor2file(val_task.task_info, tensor, osp.join(val_dir, f'val_{val_task.task_info}_gt{y}_pred{pred}_category_{class_name}_id{sample_idx}'),)

In [None]:
# Scratch cell: scale the first slice of the last saved sample by 255 and
# limit the result to the [0, 255] range.
temp = (tensor[0] * 255).clip(0, 255)
# Disabled follow-up steps (convert to uint8, reshape, display with PIL):
# temp = temp.to(dtype=torch.uint8)
# temp = temp.reshape(3, 163, 163).cpu().numpy()[::-1, ...].transpose(1, 2, 0)
# PIL.Image.fromarray(temp)

In [None]:
# Display the input shape of the last train task handled by the saving loop.
train_task.x.shape