In [1]:
from fed.server import Server
import flwr as fl
from torch.utils.tensorboard import SummaryWriter

from experiments.fed.dataset import (
    prepare_client_loader,
    get_default_aug
)
from experiments.mmvae.mnist.dataset import (
    audioMNIST, imageMNIST,
    mmMNIST
)
from torch.utils.data import DataLoader
import torch.nn as nn

from experiments.ssl.model import get_backbone
from experiments.ssl.dataset import get_mnist_transform
from fed.client import Client
from fed.config import STATE, LossMode

from pathlib import Path
import os

from experiments.mmvae.mnist.model import (
    get_mnist_audio_encoder,
    get_mnist_image_encoder
)
from model.loss import InfoNCE
import pandas as pd

import random

In [2]:
# static
def sample_client_state(num_client: int):
    """Return a shuffled list holding one modality ``STATE`` per client.

    The list keeps a 4:3:3 mix of ``STATE.BOTH`` / ``STATE.IMAGE`` /
    ``STATE.AUDIO``. Previously ``num_client`` was ignored and ten states were
    always returned; the list is now sized from the argument. For
    ``num_client == 10`` the pre-shuffle list is identical to the old
    hard-coded one, so a seeded run yields the same ordering as before.

    Args:
        num_client: number of client states to generate (must be >= 0).

    Returns:
        A list of ``num_client`` ``STATE`` members, shuffled in place with the
        global ``random`` module (seed it beforehand for reproducibility).

    Raises:
        ValueError: if ``num_client`` is negative.
    """
    if num_client < 0:
        raise ValueError(f'num_client must be non-negative, got {num_client}')

    # 30% image-only, 30% audio-only; whatever remains after flooring goes to
    # the clients that hold both modalities.
    num_image = (3 * num_client) // 10
    num_audio = (3 * num_client) // 10
    num_both = num_client - num_image - num_audio

    states = (
        [STATE.BOTH] * num_both
        + [STATE.IMAGE] * num_image
        + [STATE.AUDIO] * num_audio
    )
    random.shuffle(states)
    return states

def seed_everything(seed: int):
    """Seed the random number generators used in this notebook.

    The same ``seed`` is handed, in order, to ``np.random.seed``,
    ``random.seed``, ``torch.manual_seed`` and ``torch.cuda.manual_seed``.

    Args:
        seed: integer seed applied to every generator.
    """
    # Imported locally, as in the original cell, so the function does not
    # depend on numpy/torch having been imported at notebook level.
    import numpy as np
    import torch

    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
# Seed handed to seed_everything below.
SEED = 42
# Number of simulated clients; later cells use it to size the loader lists and
# pass it to the Flower simulation.
NUM_CLIENTS = 10
# NOTE(review): not referenced in the visible cells, which spell out
# './clients/probe/...' paths directly -- confirm whether it is still needed.
base_csv_path = './clients/probe'

# Seed before sampling so the shuffled client states are reproducible.
seed_everything(SEED)

# One modality state per client. The bare name on the last line makes the
# notebook display the list as the cell output.
clients_state = sample_client_state(NUM_CLIENTS)
clients_state

[<STATE.AUDIO: 0>,
 <STATE.BOTH: 2>,
 <STATE.BOTH: 2>,
 <STATE.AUDIO: 0>,
 <STATE.IMAGE: 1>,
 <STATE.IMAGE: 1>,
 <STATE.AUDIO: 0>,
 <STATE.IMAGE: 1>,
 <STATE.BOTH: 2>,
 <STATE.BOTH: 2>]

In [3]:
import pandas as pd
import random
from fed.utils.sampling import uniform_draw_subset
from typing import List, Tuple
from torch.utils.data import Dataset, DataLoader

from experiments.mmvae.mnist.dataset import imageMNIST, audioMNIST


# sample probing dataset
# For every client, draw a subset of its training CSV and write it to the probe
# folder (presumably 0.2 is the sampled fraction -- confirm against
# fed.utils.sampling.uniform_draw_subset).
# NUM_CLIENTS replaces the previously hard-coded 10 so this loop stays in sync
# with the loader-construction loop further down in this cell.
for i in range(NUM_CLIENTS):
    file_path = f'./clients/train/{i}.csv'
    save_path = f'./clients/probe/{i}.csv'
    uniform_draw_subset(file_path, 0.2, save_path)
    
# split probe & test file into each modality
# Each header-less client CSV is split column-wise: column 0 is written to
# '<cid>_audio.csv' under an 'audio' header and column 1 to '<cid>_image.csv'
# under an 'image' header.
for state in ['probe', 'test']:
    # NUM_CLIENTS replaces the previously hard-coded 10.
    for i in range(NUM_CLIENTS):
        file_path = f'./clients/{state}/{i}.csv'
        df = pd.read_csv(file_path, header=None)
        # to_csv returns None when given a path, so the result is no longer
        # bound to the misleading names audio_df / image_df.
        df.iloc[:, 0].to_csv(f'./clients/{state}/{i}_audio.csv', index=False, header=['audio'])
        df.iloc[:, 1].to_csv(f'./clients/{state}/{i}_image.csv', index=False, header=['image'])

# load dataloader for each modality


# Per-client loaders, indexed by client id; every entry is the pair
# [audio_loader, image_loader].
probe_loaders: List[List[DataLoader]] = []
test_loaders: List[List[DataLoader]] = []

dataloader_config = {
    'batch_size' : 32,
    'shuffle' : True
}

# Dispatch tables replace the if/else chains. Insertion order matters:
# 'probe' is built before 'test', and 'audio' before 'image'.
dataset_by_modality = {'audio': audioMNIST, 'image': imageMNIST}
loaders_by_split = {'probe': probe_loaders, 'test': test_loaders}

# construct probe & test audio and image loader for each client
for split, split_loaders in loaders_by_split.items():
    for cid in range(NUM_CLIENTS):
        split_loaders.append([
            DataLoader(
                dataset_cls(csv_path=f'./clients/{split}/{cid}_{mod}.csv'),
                **dataloader_config
            )
            for mod, dataset_cls in dataset_by_modality.items()
        ])

In [4]:
from fed.server import LinearProbeServer
from fed.client import LinearProbeClient
from torch.utils.tensorboard import SummaryWriter

from experiments.mmvae.mnist.model import (
    get_mnist_audio_encoder,
    get_mnist_image_encoder,
    _make_mlp
)
from model.loss import InfoNCE
import torch

# Checkpoint file read by load_pretrained_backbone; it is indexed by modality
# name there.
CKP_PATH = './ckp/backbone/272/PoE.pt'
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Modality under evaluation. 'audio' selects the audio encoder and loader in
# setup_client; any other value takes the image branch. It is also the key
# looked up in the checkpoint.
EVAL_MODAL = 'audio'

def load_pretrained_backbone(model, ckp_path, modal=None):
    """Load one modality's pretrained weights from a checkpoint into ``model``.

    The checkpoint is indexed by modality name and the selected entry is
    passed to ``model.load_state_dict``.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        ckp_path: path of the checkpoint file.
        modal: key to read from the checkpoint. Defaults to the notebook-level
            ``EVAL_MODAL`` when omitted, which matches the previous behaviour.

    Returns:
        The same ``model`` object, with the weights loaded.

    Raises:
        KeyError: if the checkpoint has no entry for ``modal``.
    """
    if modal is None:
        modal = EVAL_MODAL

    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # map_location='cpu' lets a checkpoint saved on a GPU be read on a machine
    # without CUDA; load_state_dict then copies the tensors onto whatever
    # device the model's parameters already live on.
    ckp = torch.load(ckp_path, map_location='cpu')
    if modal not in ckp:
        raise KeyError(f'checkpoint {ckp_path!r} has no entry for modality {modal!r}')

    model.load_state_dict(ckp[modal])
    return model

# State dict of a freshly constructed encoder for the evaluated modality
# ('audio' -> audio encoder, anything else -> image encoder).
dummy_state_dict = (
    get_mnist_audio_encoder() if EVAL_MODAL == 'audio' else get_mnist_image_encoder()
).state_dict()

def setup_client(cid: str):
    """Build the linear-probe client for one client id.

    Used as the ``client_fn`` of the Flower simulation. The client gets the
    probe loader and pretrained encoder of the modality named by
    ``EVAL_MODAL``, plus a fresh ``nn.Linear(64, 10)`` head trained with
    cross-entropy.

    Args:
        cid: client id; converted to ``int`` and used to index
            ``probe_loaders``.

    Returns:
        A ``LinearProbeClient`` moved to ``DEVICE``.
    """
    client_index = int(cid)
    audio_dl, image_dl = probe_loaders[client_index]

    # 'audio' selects the audio pipeline; any other value falls back to image.
    use_audio = EVAL_MODAL == 'audio'
    dataloader = audio_dl if use_audio else image_dl
    make_encoder = get_mnist_audio_encoder if use_audio else get_mnist_image_encoder

    # The encoder is moved to the device first, then the pretrained weights
    # are loaded into it.
    pretrained_backbone = load_pretrained_backbone(make_encoder().to(DEVICE), CKP_PATH)

    # linear probe client
    return LinearProbeClient(
        client_index,
        pretrained_backbone=pretrained_backbone,
        head=nn.Linear(64, 10),
        loss_fn=nn.CrossEntropyLoss(),
        trainloader=dataloader,
        device=DEVICE,
    ).to(DEVICE)

In [5]:
import flwr as fl
from utils.scheduler import SineAnnealing
# The TensorBoard log directory and the head checkpoint path are derived from
# EVAL_MODAL instead of hard-coding 'audio', so switching the evaluated
# modality no longer overwrites the other modality's logs and checkpoints.
# With EVAL_MODAL == 'audio' both paths are identical to the previous literals.
writer = SummaryWriter(f'/root/tf-logs/272/probe/PoE/{EVAL_MODAL}')

ckp_save_path = f'./ckp/head/272/PoE_{EVAL_MODAL}.pt'
optim_config = {
    'lr' : 1e-2,
    'weight_decay' : 1e-5
}

# baseline
# embed_dim / num_class mirror the nn.Linear(64, 10) head built in setup_client.
strategy = LinearProbeServer(
    NUM_CLIENTS,
    embed_dim=64,
    num_class=10,
    local_epoch=5,
    optim_config=optim_config,
    tensorboard_writer=writer,
    save_ckp_interval=5,
    save_path=ckp_save_path
)

ray_init_args = {
    'include_dashboard' : True
}

# Run the federated linear-probe simulation for 60 rounds.
# A GPU share is only requested per client when DEVICE is 'cuda'. Asking Ray
# for a fraction of a GPU on a CPU-only machine leaves the client tasks
# unschedulable, although DEVICE itself falls back to 'cpu'.
# NOTE(review): 'max_calls' is presumably forwarded to Ray's task options --
# confirm it is still accepted by the installed Flower version.
fl.simulation.start_simulation(
    strategy=strategy,
    client_fn=setup_client,
    num_clients=NUM_CLIENTS,
    client_resources={
        'num_cpus': 2,
        'num_gpus': 0.1 if DEVICE == 'cuda' else 0,
        'max_calls': 1,
    },
    config=fl.server.ServerConfig(num_rounds=60),
    ray_init_args=ray_init_args
)

INFO flwr 2023-07-12 21:44:41,269 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=60, round_timeout=None)
2023-07-12 21:44:43,633	INFO worker.py:1627 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266
INFO flwr 2023-07-12 21:44:44,415 | app.py:180 | Flower VCE: Ray initialized with resources: {'object_store_memory': 11386830028.0, 'memory': 22773660059.0, 'node:172.17.0.8': 1.0, 'accelerator_type:G': 1.0, 'GPU': 1.0, 'CPU': 16.0}
INFO flwr 2023-07-12 21:44:44,416 | server.py:86 | Initializing global parameters
INFO flwr 2023-07-12 21:44:44,420 | server.py:269 | Using initial parameters provided by strategy
INFO flwr 2023-07-12 21:44:44,421 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-07-12 21:44:44,422 | server.py:101 | FL starting
DEBUG flwr 2023-07-12 21:44:44,423 | server.py:218 | fit_round 1: strategy sampled 10 clients (out of 10)
DEBUG flwr 2023-07-12 21:44:55,891 | server.py:232 | fit_round

