In [1]:
import pandas as pd
import random
from fed.utils.sampling import uniform_draw_subset
SEED = 42

random.seed(SEED)

for i in range(10):
    file_path = f'./clients/train/{i}.csv'
    save_path = f'./clients/probe/{i}.csv'
    uniform_draw_subset(file_path, 0.8, save_path)
    
# split probe & test file into each modality
for state in ['probe', 'test']:
    for i in range(10):
        file_path = f'./clients/{state}/{i}.csv'
        df = pd.read_csv(file_path, header=None)
        audio_df = df.iloc[:,0].to_csv(f'./clients/{state}/{i}_audio.csv', index=False, header=['audio']) 
        image_df = df.iloc[:,1].to_csv(f'./clients/{state}/{i}_image.csv', index=False, header=['image']) 


In [2]:
from typing import List, Tuple
from torch.utils.data import Dataset, DataLoader

from experiments.mmvae.mnist.dataset import imageMNIST, audioMNIST

NUM_CLIENT = 10
probe_loaders: List[List[DataLoader]] = []
test_loaders: List[List[DataLoader]] = []

dataloader_config = {
    'batch_size' : 32,
    'shuffle' : True
}

# construct probe & test audio and image loader for each client 
for state in ['probe', 'test']:
    for cid in range(NUM_CLIENT):
        client_loader = []
        for mod in ['audio', 'image']:
            file_path = f'./clients/{state}/{cid}_{mod}.csv'
            if mod == 'audio':
                dataset = audioMNIST(
                    csv_path=file_path 
                )
            else:
                dataset = imageMNIST(
                    csv_path=file_path
                )
            dataloader = DataLoader(dataset, **dataloader_config)
            client_loader.append(dataloader)
        if state == 'probe':
            probe_loaders.append(client_loader)
        else: # state == 'test'
            test_loaders.append(client_loader)
        

In [58]:
import torch
import torch.nn as nn
from reconstruct.mmvae import DecoupledMMVAE
from experiments.mmvae.mnist.model import _make_mlp
from experiments.mmvae.mnist.model import (
    get_mnist_audio_encoder,
    get_mnist_image_encoder
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load model
ckp_path = './iid_baseline.pt'
ckp = torch.load(ckp_path)

audio_extractor = get_mnist_audio_encoder().to(device)
image_extractor = get_mnist_image_encoder().to(device)

def make_mlp(inplanes, hidden_dim, out_dim, use_bn=False):
    if use_bn:
        return nn.Sequential(
            nn.Linear(inplanes, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    else:
        return nn.Sequential(
            nn.Linear(inplanes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )
    

mmvae_encoder = nn.ModuleDict({
        'audio' : make_mlp(64, 256, 128, use_bn=True),
        'image': make_mlp(64, 256, 128, use_bn=True)
    })
mmvae_decoder = model = nn.ModuleDict({
        'audio' : make_mlp(64, 256, 64, use_bn=True),
        'image':  make_mlp(64, 256, 64, use_bn=True)
    })
score_fns = nn.ModuleDict({
        'audio' : nn.MSELoss(),
        'image' : nn.MSELoss()
    })
mmvae = DecoupledMMVAE(
    encoders=mmvae_encoder,
    decoders=mmvae_decoder,
    latent_dim=64,
    score_fns=score_fns,
    device=device
)
# mmvae.load_state_dict(ckp['mmvae'])



In [59]:
audio_dl, image_dl = test_loaders[0]

torch.manual_seed(42)


audio_dl_iter = iter(audio_dl)
image_dl_iter = iter(image_dl)
batch_audio_1, audio_y_1 = next(audio_dl_iter)
batch_audio_2, audio_y_2 = next(audio_dl_iter)

batch_image_1, image_y_1 = next(image_dl_iter)
batch_image_2, image_y_2 = next(image_dl_iter)

# load models
audio_extractor.load_state_dict(ckp['audio'])
image_extractor.load_state_dict(ckp['image'])

with torch.no_grad():
    audio_extractor.eval()
    image_extractor.eval()
    
    audio_feature_1 = audio_extractor(batch_audio_1.to(device))
    audio_feature_2 = audio_extractor(batch_audio_2.to(device))
    
    image_feature_1 = image_extractor(batch_image_1.to(device))
    image_feature_2 = image_extractor(batch_image_2.to(device))
    

In [60]:
audio_y_1

tensor([1, 9, 1, 8, 9, 9, 1, 2, 1, 9, 7, 1, 1, 7, 4, 1, 1, 7, 1, 7, 7, 3, 3, 9,
        2, 1, 1, 0, 7, 1, 3, 2])

In [61]:
audio_y_2

tensor([2, 3, 7, 0, 1, 8, 3, 7, 1, 9, 1, 0, 1, 7, 1, 7, 8, 3, 1, 4, 0, 1, 3, 3,
        1, 7, 3, 3, 2, 9, 1, 9])

In [62]:
image_y_1

tensor([9, 9, 9, 1, 3, 1, 8, 7, 8, 5, 1, 9, 1, 1, 8, 9, 2, 1, 9, 7, 2, 2, 1, 7,
        7, 9, 1, 9, 1, 4, 9, 9])

In [67]:
audio = [0, 1, 2, 3]
image = [0, 1, 2, 3]

def pair_retrieval(x: torch.Tensor, y: torch.Tensor, normalize=True):
    # if normalize set to True, do normalize before compute dot product
    y_normalizer = torch.linalg.vector_norm(y, dim=-1, keepdim=True)
    normed_y = y / y_normalizer
    
    score_matrix = x @ normed_y.T
    query_res = torch.argmax(score_matrix, dim=-1)
    return query_res
    
keys = image_feature_2
pool = image_feature_1

keys_label = image_y_2
pool_label = image_y_1

query_res = pair_retrieval(keys, pool)
print(query_res)
for key, queried_idx in zip(keys_label, query_res):
    print(f'key: {key} -> query: {pool_label[queried_idx]}')

tensor([ 8, 23, 22, 24, 10,  1,  9, 16,  1, 12,  4,  7,  9,  5, 15, 11,  9,  9,
         7, 10, 23,  5, 11, 17,  8, 24, 23,  2, 15,  9,  9, 16],
       device='cuda:0')
key: 2 -> query: 8
key: 7 -> query: 7
key: 1 -> query: 1
key: 7 -> query: 7
key: 1 -> query: 1
key: 9 -> query: 9
key: 2 -> query: 5
key: 2 -> query: 2
key: 9 -> query: 9
key: 1 -> query: 1
key: 3 -> query: 3
key: 2 -> query: 7
key: 3 -> query: 5
key: 1 -> query: 1
key: 9 -> query: 9
key: 9 -> query: 9
key: 3 -> query: 5
key: 2 -> query: 5
key: 7 -> query: 7
key: 1 -> query: 1
key: 7 -> query: 7
key: 1 -> query: 1
key: 7 -> query: 9
key: 1 -> query: 1
key: 2 -> query: 8
key: 7 -> query: 7
key: 7 -> query: 7
key: 2 -> query: 9
key: 8 -> query: 9
key: 2 -> query: 5
key: 2 -> query: 5
key: 2 -> query: 2


In [26]:
image_feature[0,:5]

tensor([-0.0578,  0.0135,  0.0449,  0.0619, -0.0552], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [27]:
audio_feature[0,:5]

tensor([-0.0118, -0.0213, -0.0110, -0.0219, -0.0025], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [None]:
# linear probe on client's local dataset using server's model
# then record performance on test set
from utils.train import linear_prob, val
from experiments.mmvae.mnist.model import _make_mlp
import torch.optim as optim
import torch.nn as nn

res = []
NUM_CLIENTS = 10

def prepare_models(audio_model, image_model, ckp):
    audio_model.load_state_dict(ckp['audio'])
    image_model.load_state_dict(ckp['image'])
    return audio_model, image_model

for cid in range(NUM_CLIENTS):
    # probe
    probe_dls = probe_loaders[cid]
    test_dls = test_loaders[cid]
    
    local_audio_model, local_image_model = prepare_models(
        audio_extractor, image_extractor, ckp
    )

    
    to_probe = [local_audio_model, local_image_model]
    client_eva_res = []
    idx2modname = ['audio', 'image']
    for idx, model in enumerate(to_probe):
        model.eval()
        head = _make_mlp(64, 256, 10).to(device)
        optimizer = optim.Adam(head.parameters(), lr=1e-3, weight_decay=1e-5)
        criterion = nn.CrossEntropyLoss()
    
        linear_prob(
            model, head, 
            probe_dls[idx],
            optimizer, criterion,
            device,
            n_epoch=30,
            use_tqdm=False
        )
        
        # eval
        accuracy = val(model, test_dls[idx], device)
        print(f'client {cid} mod {idx2modname[idx]}: {100 * accuracy:.2f}%')
        client_eva_res.append(accuracy)
    res.append(client_eva_res)