In [1]:
from experiments.mmvae.mnist.model import (
    _make_mlp,
    get_mnist_audio_encoder,
    get_mnist_image_encoder
)
import torch
import torch.nn as nn

# Modality whose encoder + linear head are evaluated in this notebook.
MODAL = 'audio'
SWITCH = 'hybrid'

# NOTE(review): the head checkpoint path is hard-coded to the audio head; if MODAL
# is changed to 'image' this path presumably has to change too — confirm.
HEAD_CKP_PATH = './ckp/head/272/PoE_audio.pt'
MAIN_CKP_PATH = './ckp/backbone/272/PoE.pt'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# load pretrained models
# map_location remaps the stored tensors onto DEVICE, so checkpoints that were
# saved from a GPU run can still be loaded on a CPU-only machine.
head_ckp = torch.load(HEAD_CKP_PATH, map_location=DEVICE)
backbone_ckp = torch.load(MAIN_CKP_PATH, map_location=DEVICE)

# Linear head: 64 input features -> 10 outputs
# (presumably encoder feature size and number of digit classes — TODO confirm).
head = nn.Linear(64, 10)
head.load_state_dict(head_ckp)

# Any value other than 'audio' falls through to the image encoder.
if MODAL == 'audio':
    backbone = get_mnist_audio_encoder()
else:  # image
    backbone = get_mnist_image_encoder()
# The backbone checkpoint is indexed by modality name; as the last expression of
# the cell, the load result (key-matching report) is displayed.
backbone.load_state_dict(backbone_ckp[MODAL])

<All keys matched successfully>

In [2]:
import torch.nn as nn
class Classifier(nn.Module):
    """Inference-only classifier: a frozen feature extractor followed by a head.

    The backbone is kept permanently in eval mode and the whole forward pass
    runs without gradient tracking, so this wrapper is meant for evaluation
    (e.g. scoring a previously trained head), not for training.

    Args:
        backbone: module mapping an input batch to a feature tensor.
        head: module mapping the backbone features to predictions.
    """

    def __init__(self,
                 backbone: nn.Module,
                 head: nn.Module
                ):
        super().__init__()
        self.backbone, self.head = backbone, head

        # toggle feature extractor state
        self.backbone.eval()

    def train(self, mode: bool = True):
        """Set the training mode of the wrapper while keeping the backbone in eval.

        ``nn.Module.train`` recurses into every sub-module, which would silently
        undo the ``eval()`` applied to the backbone in ``__init__``. Re-applying
        it here keeps the feature extractor in eval mode regardless of how the
        wrapper is toggled.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    @torch.no_grad()
    def forward(self, x):
        """Return the head's predictions for ``x``; no gradients are recorded."""
        feature = self.backbone(x)
        pred = self.head(feature)
        return pred

In [3]:
import pandas as pd
import os

NUM_CLIENTS = 10
test_path = './clients/test'

# Never filled in the visible cells; kept as-is.
merged_dfs = []

# For each modality, stack the per-client test CSVs row-wise and write them out
# as a single combined CSV in the working directory.
for mod in ['audio', 'image']:
    save_path = f'./{mod}_test_total.csv'
    # One frame per client id, read from '<test_path>/<cid>_<mod>.csv'.
    dfs = [
        pd.read_csv(os.path.join(test_path, f'{cid}_{mod}.csv'))
        for cid in range(NUM_CLIENTS)
    ]
    merged_df = pd.concat(dfs, axis=0)
    # index=False: the row index is not written to the output file.
    merged_df.to_csv(save_path, index=False)


In [4]:
from experiments.mmvae.mnist.dataset import (
    audioMNIST, imageMNIST
)
from torch.utils.data import DataLoader

AUDIO_TEST_PATH = './audio_test_total.csv'
IMAGE_TEST_PATH = './image_test_total.csv'

# Loader settings shared by both modalities: batches of 128, no shuffling.
dl_config = dict(batch_size=128, shuffle=False)

# Audio test set and its loader.
audio_dataset = audioMNIST(csv_path=AUDIO_TEST_PATH)
audio_dl = DataLoader(audio_dataset, **dl_config)

# Image test set and its loader.
image_dataset = imageMNIST(csv_path=IMAGE_TEST_PATH)
image_dl = DataLoader(image_dataset, **dl_config)

In [5]:
from utils.train import val
# Use the loader matching the selected modality; any value other than 'audio'
# selects the image loader.
dataloader = audio_dl if MODAL == 'audio' else image_dl

# Wrap backbone + head into one module and move it to the evaluation device.
clf = Classifier(backbone, head)
clf = clf.to(DEVICE)

# Score the classifier on the selected test loader.
# NOTE(review): `val` is assumed to return a scalar tensor-like accuracy
# (it is consumed via `.item()` below) — confirm against utils.train.
acc = val(clf, dataloader, DEVICE)
print('{} linear prob res: {:.2f}'.format(MODAL, acc.item()))

audio linear prob res: 0.47
