In [4]:
from reconstruct.mmvae import DecoupledMMVAE
from experiments.mmvae.mnist.model import _make_mlp

import torch
import torch.nn as nn

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def make_mlp(inplanes, hidden_dim, out_dim, use_bn=False):
    """Build a two-layer perceptron as an ``nn.Sequential``.

    Layout: ``Linear(inplanes, hidden_dim)`` -> optional
    ``BatchNorm1d(hidden_dim)`` -> ``ReLU`` -> ``Linear(hidden_dim, out_dim)``.

    Args:
        inplanes: size of each input feature vector.
        hidden_dim: width of the hidden layer.
        out_dim: size of each output vector.
        use_bn: insert batch normalisation after the first linear layer.

    Returns:
        The assembled ``nn.Sequential`` (4 modules with batch norm, 3 without).
    """
    # Layers are created in forward order so module indices (and therefore
    # state-dict keys) are the same as writing the Sequential out by hand.
    layers = [nn.Linear(inplanes, hidden_dim)]
    if use_bn:
        layers.append(nn.BatchNorm1d(hidden_dim))
    layers.append(nn.ReLU())
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)
    

# Modalities handled by the model; the tuple order fixes both the ModuleDict
# key order and the order in which the sub-networks are initialised.
modality_names = ('audio', 'image')

# One batch-normalised MLP encoder per modality (64 -> 32 -> 32).
encoders = nn.ModuleDict(
    {name: make_mlp(64, 32, 32, use_bn=True) for name in modality_names}
).to(device)

# One batch-normalised MLP decoder per modality (16 -> 32 -> 64).
decoders = nn.ModuleDict(
    {name: make_mlp(16, 32, 64, use_bn=True) for name in modality_names}
).to(device)

# Mean-squared-error score function for each modality.
score_fns = nn.ModuleDict(
    {name: nn.MSELoss() for name in modality_names}
).to(device)

# The third positional argument (16) matches the decoder input width above;
# presumably the latent dimension -- confirm against DecoupledMMVAE.
mmvae = DecoupledMMVAE(encoders, decoders, 16, score_fns, device).to(device)
    
# load pretrained mmvae from last iter
ckp_path = './ckp/components/13/mmvae_moco_test_otherway_train.pt'
# map_location remaps every stored tensor onto `device`, so a checkpoint saved
# on a GPU still loads when that exact device is not available here.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
ckp = torch.load(ckp_path, map_location=device)

# mmvae.load_state_dict(ckp['mmvae'])

In [5]:
# try other mmvae method
from experiments.mmvae.mnist.dataset import mmMNIST


# CSV file handed to the mmMNIST dataset constructor.
# NOTE(review): absolute, machine-specific path -- consider deriving it from a
# configurable data directory.
public_dataset_path = '/root/autodl-tmp/csv/mmMNIST_server.csv'
# Build the dataset from that CSV path.
public_dataset = mmMNIST(public_dataset_path)
# Shown as the cell output (the recorded run printed 8000).
len(public_dataset)

8000

In [8]:
from fed.utils.pipeline import PairedFeatureBank, MMVAETrainer
from experiments.mmvae.mnist.model import (
    get_mnist_audio_encoder,
    get_mnist_image_encoder
)

# Instantiate the audio backbone on `device` and restore its weights from the
# 'audio' entry of the checkpoint loaded earlier.
audio_backbone = get_mnist_audio_encoder().to(device)
audio_backbone.load_state_dict(ckp['audio'])

# Same for the image backbone, using the checkpoint's 'image' entry.
image_backbone = get_mnist_image_encoder().to(device)
image_backbone.load_state_dict(ckp['image'])

# Wrap the public dataset together with both backbones (audio first, image
# second). Presumably this precomputes paired feature embeddings for MMVAE
# training -- confirm against PairedFeatureBank.
# NOTE(review): the backbones are not switched to eval() here; verify whether
# PairedFeatureBank does that itself.
embed_dataset = PairedFeatureBank(
        public_dataset,
        (audio_backbone, image_backbone, ),
        device
)

In [9]:
# Settings passed to MMVAETrainer for batching the embedding dataset.
dataloader_config = {
            'batch_size' : 32,
            'shuffle' : True,
        }
# Settings passed to MMVAETrainer for its optimiser (learning rate and weight
# decay); which optimiser is used is decided inside the trainer.
optim_config = {
    'lr' : 1e-2,
    'weight_decay' : 1e-5
}
# Fit the MMVAE on the paired feature bank.
# NOTE(review): cross_loss_only=False presumably keeps every loss term, not
# only the cross-modal one -- confirm against MMVAETrainer.
trainer = MMVAETrainer(
    mmvae,
    embed_dataset,
    dataloader_config,
    optim_config,
    cross_loss_only=False
)
# NOTE(review): the recorded run of this cell stopped right after the progress
# bars appeared with "AttributeError: 'float' object has no attribute 'item'".
# The failing statement is not in this notebook; it has to be fixed in the
# trainer / model code before the cells below can rely on `fitted_mmvae`.
trainer.train()
# Retrieve the trained model state from the trainer (exact return type is
# defined by MMVAETrainer.export_mmvae).
fitted_mmvae = trainer.export_mmvae()

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/250 [00:00<?, ?it/s]

AttributeError: 'float' object has no attribute 'item'

In [None]:
# mmvae.load_state_dict(fitted_mmvae)

In [None]:
# Restore the MMVAE weights stored under the checkpoint's 'mmvae' key.
mmvae.load_state_dict(ckp['mmvae'])

In [None]:
# Switch the model to evaluation mode before sampling.
mmvae.eval()
# Number of samples requested from the model; later cells reuse this value.
num_sample = 1000
# `res` is indexed by modality name below ('audio' here, 'image' later).
res = mmvae.generate(num_sample=num_sample)
# Shown as the cell output: shape of the generated audio embeddings.
res['audio'].shape

In [10]:
import numpy as np
def _to_numpy(t):
    """Return tensor ``t`` as a NumPy array.

    The tensor is made contiguous, detached from the autograd graph and moved
    to the CPU, in that order, before ``.numpy()`` is called.
    """
    contiguous_t = t.contiguous()
    host_t = contiguous_t.detach().cpu()
    return host_t.numpy()


# Generated embeddings of both modalities as NumPy arrays.
gen_audio_embed, gen_image_embed = (
    _to_numpy(res[key]) for key in ('audio', 'image')
)

# `num_sample` zeros (audio) followed by `num_sample` ones (image).
labels = np.array([0] * num_sample + [1] * num_sample)

NameError: name 'res' is not defined

In [None]:
from sklearn.manifold import TSNE
import matplotlib.cm as cm
import matplotlib.pyplot as plt


# Stack the generated audio and image embeddings and project them to 2-D.
feature_bank = np.concatenate((gen_audio_embed, gen_image_embed), axis=0)
tsne = TSNE(n_components=2, learning_rate='auto', init='random')
embeds = tsne.fit_transform(feature_bank)

# One scatter series per modality, selected through the integer labels.
plt.figure(figsize=(10, 10))
names = ['gen_audio', 'gen_image']
for idx, name in enumerate(names):
    rows = np.where(labels == idx)[0]
    xs, ys = embeds[rows, 0], embeds[rows, 1]
    plt.scatter(xs, ys, label=name)
plt.legend()
plt.show()

In [11]:
# load some real image/audio data
from experiments.mmvae.mnist.dataset import (
    audioMNIST, imageMNIST
)
from torch.utils.data import DataLoader

# CSV files handed to the single-modality test datasets (relative to the
# notebook's working directory).
AUDIO_TEST_PATH = './audio_test_total.csv'
IMAGE_TEST_PATH = './image_test_total.csv'

# test audio clf
# One dataset per modality, each built from its CSV path.
audio_dataset = audioMNIST(csv_path=AUDIO_TEST_PATH)
image_dataset = imageMNIST(csv_path=IMAGE_TEST_PATH)

# Shared loader settings: no shuffling, so the first batch is the same on every
# run. A later cell takes only the first batch from each loader, so batch_size
# sets how many real samples get embedded.
dl_config = {
    'batch_size' : 1000,
    'shuffle' : False
}
audio_dl = DataLoader(audio_dataset, **dl_config)
image_dl = DataLoader(image_dataset, **dl_config)


In [None]:
from experiments.mmvae.mnist.model import (
    get_mnist_audio_encoder,
    get_mnist_image_encoder
)

# Create fresh backbone instances on `device` and restore their weights from
# the 'audio' / 'image' entries of the checkpoint.
# NOTE(review): this repeats the backbone setup of an earlier cell and rebinds
# the same two names to new objects.
audio_backbone = get_mnist_audio_encoder().to(device)
audio_backbone.load_state_dict(ckp['audio'])

image_backbone = get_mnist_image_encoder().to(device)
image_backbone.load_state_dict(ckp['image'])

In [None]:
# compute embeds
# Take the first batch from each test loader; the second item of each batch
# is discarded.
audio_batch, _ = next(iter(audio_dl))
image_batch, _ = next(iter(image_dl))

# Inference only: eval() makes layers such as batch norm / dropout (if the
# backbones contain any) run deterministically, and no_grad() avoids building
# an autograd graph for the whole batch.
audio_backbone.eval()
image_backbone.eval()
with torch.no_grad():
    real_audio_embed = _to_numpy(audio_backbone(audio_batch.to(device)))
    real_image_embed = _to_numpy(image_backbone(image_batch.to(device)))

In [None]:
# Real embeddings of both modalities, listed in legend order.
names = ['real_audio', 'real_image']
embed_groups = [real_audio_embed, real_image_embed]
feature_bank = np.concatenate(embed_groups, axis=0)

# One integer label per row of `feature_bank`, taken from the real size of
# each group. This keeps labels and rows aligned even if a batch is shorter
# than expected, and removes the dependency on `num_sample` from another cell.
labels = np.repeat(
    np.arange(len(embed_groups)),
    [len(group) for group in embed_groups],
)

# Project to 2-D (t-SNE with random init, so the layout differs between runs).
embeds = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(feature_bank)

plt.figure(figsize=(10, 10))
for idx, name in enumerate(names):
    indices = np.where(labels == idx)
    plt.scatter(embeds[indices, 0], embeds[indices, 1], label=f'{name}')
plt.legend()
plt.show()

In [None]:
# try conditional generation
# Feed only the real audio embeddings to the MMVAE and read back both
# modalities from its reconstruction output.
wrapped_inputs = {'audio' : torch.from_numpy(real_audio_embed).to(device)}
cond_gen = mmvae.reconstruct(wrapped_inputs)

image_cond_audio = _to_numpy(cond_gen['image'])
audio_cond_audio = _to_numpy(cond_gen['audio'])

# Groups in legend order; the last two are the per-modality mean embeddings
# (a single row each).
names = ['real_audio', 'real_image', 'rec_audio', 'rec_image', 'c_a', 'c_i']
embed_groups = [
    real_audio_embed,
    real_image_embed,
    audio_cond_audio,
    image_cond_audio,
    np.mean(real_audio_embed, axis=0, keepdims=True),
    np.mean(real_image_embed, axis=0, keepdims=True),
]
feature_bank = np.concatenate(embed_groups, axis=0)

# One integer label per row, taken from the real size of each group so labels
# and rows stay aligned without relying on `num_sample` from another cell.
labels = np.repeat(
    np.arange(len(embed_groups)),
    [len(group) for group in embed_groups],
)

# Project to 2-D (t-SNE with random init, so the layout differs between runs).
embeds = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(feature_bank)

plt.figure(figsize=(10, 10))
for idx, name in enumerate(names):
    indices = np.where(labels == idx)
    plt.scatter(embeds[indices, 0], embeds[indices, 1], label=f'{name}')
plt.legend()
plt.show()

In [None]:
# try conditional generation
# Feed only the real image embeddings to the MMVAE and read back both
# modalities from its reconstruction output.
wrapped_inputs = {'image' : torch.from_numpy(real_image_embed).to(device)}
cond_gen = mmvae.reconstruct(wrapped_inputs)

image_cond_image = _to_numpy(cond_gen['image'])
audio_cond_image = _to_numpy(cond_gen['audio'])

# Groups in legend order. Both reconstructed groups come from THIS cell's
# image-conditioned output; the previous version plotted `image_cond_audio`,
# a leftover from the audio-conditioned cell.
names = ['real_audio', 'real_image', 'rec_audio', 'rec_image']
embed_groups = [real_audio_embed, real_image_embed, audio_cond_image, image_cond_image]
feature_bank = np.concatenate(embed_groups, axis=0)

# One integer label per row, taken from the real size of each group so labels
# and rows stay aligned without relying on `num_sample` from another cell.
labels = np.repeat(
    np.arange(len(embed_groups)),
    [len(group) for group in embed_groups],
)

# Project to 2-D (t-SNE with random init, so the layout differs between runs).
embeds = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(feature_bank)

plt.figure(figsize=(10, 10))
for idx, name in enumerate(names):
    indices = np.where(labels == idx)
    plt.scatter(embeds[indices, 0], embeds[indices, 1], label=f'{name}')
plt.legend()
plt.show()