In [1]:
import torch
import torch.nn.functional as F

# --- Experiment configuration -------------------------------------------------
# The relative paths below are resolved after the `%cd` that follows this block.

# Not referenced anywhere else in the visible notebook.
HEAD_CKP = None
# Checkpoint read by load_models (keys 'audio' / 'image') and by load_rec_net
# (keys 'a2i' / 'i2a').
BACKBONE_CKP = './ckp/components/99/debug_mm_only_iid.pt'
# BACKBONE_CKP = './extra/reconable_vanilla.pt'
# Number of clients iterated over; CLIENT_STATES is indexed by the same ids.
NUM_CLIENT = 10
# Directories holding the per-client CSV files for probing and testing.
PROBE_PATH = './clients/probe'
TEST_PATH = './clients/test'
# Seed passed to seed_everything() once at the end of this cell.
SEED = 272
# 43  (presumably an alternative seed value that was tried - confirm)

# Target device string used by every `.to(device)` call in the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
%cd /root/fedRec

def seed_everything(seed: int):
    """Seed the NumPy, Python ``random`` and PyTorch (CPU and CUDA) generators.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    import random
    import numpy as np
    import torch

    # Same seeding calls, in the same order: NumPy, random, torch CPU, torch CUDA.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    
from fed.config import STATE
from typing import Callable, List, Any, Tuple
    
# Modality state of each client, indexed by client id (looked up as
# CLIENT_STATES[cid]); the list has one entry per client, NUM_CLIENT = 10.
CLIENT_STATES = [
 STATE.AUDIO,  # client 0
 STATE.BOTH,  # client 1
 STATE.BOTH,  # client 2
 STATE.AUDIO,  # client 3
 STATE.IMAGE,  # client 4
 STATE.IMAGE,  # client 5
 STATE.AUDIO,  # client 6
 STATE.IMAGE,  # client 7
 STATE.BOTH,  # client 8
 STATE.BOTH]  # client 9


# Seed all RNGs once for the whole notebook session.
seed_everything(SEED)

/root/fedRec


In [2]:
import torch.nn as nn
from experiments.mmvae.mnist.model import (
    get_mnist_image_encoder,
    get_mnist_audio_encoder
)
from typing import List, Tuple
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from model.simclr import StandardPipeline


def get_proj_head():
    """Build the projection head: one linear layer mapping 64 features to 128."""
    projection = nn.Linear(in_features=64, out_features=128)
    return projection


# Selects the post-load behaviour of load_models: 'new' switches both encoders
# to eval mode, any other value returns them in the mode they were built in.
version = 'new'

def load_models(ckp_path):
    """Build the audio and image encoders and load their weights from a checkpoint.

    Args:
        ckp_path: Path to a checkpoint whose 'audio' and 'image' entries are the
            state dicts of the respective encoders.

    Returns:
        Tuple ``(audio_model, image_model)``, both moved to ``device``.
    """
    # map_location remaps tensors that were saved on a GPU, so the checkpoint
    # still loads when `device` has fallen back to 'cpu'.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckp = torch.load(ckp_path, map_location=device)

    # Construction order (audio first, then image) is kept because building a
    # module consumes random numbers from the seeded generator.
    audio_model = get_mnist_audio_encoder().to(device)
    image_model = get_mnist_image_encoder().to(device)

    audio_model.load_state_dict(ckp['audio'])
    image_model.load_state_dict(ckp['image'])

    # The two code paths used to be duplicated; this is their only difference.
    if version == 'new':
        audio_model.eval()
        image_model.eval()

    return audio_model, image_model
        

import torchvision.transforms as T
from experiments.mmvae.mnist.dataset import imageMNIST, audioMNIST, mmMNIST

from experiments.ssl.dataset import get_mnist_transform

def load_dls(cid, probe_path, test_path, dl_config=None):
    """Create per-modality probe and test loaders for one client.

    Args:
        cid: Client id used to pick ``{cid}_audio.csv`` / ``{cid}_image.csv``.
        probe_path: Directory containing the probing CSV files.
        test_path: Directory containing the test CSV files.
        dl_config: Keyword arguments for the probe ``DataLoader``s; defaults to
            batch size 64 with shuffling. Test loaders always use batch size 64
            without shuffling.

    Returns:
        ``(probe_dls, test_dls)``, each a list ``[audio_loader, image_loader]``.
    """
    probe_kwargs = {'batch_size': 64, 'shuffle': True} if dl_config is None else dl_config
    test_kwargs = {'batch_size': 64, 'shuffle': False}

    def _loader_pair(root, loader_kwargs):
        # Audio set first, then a freshly built image transform and image set,
        # then the two loaders in [audio, image] order.
        audio_set = audioMNIST(csv_path=f'{root}/{cid}_audio.csv')
        _, transform_steps = get_mnist_transform()
        image_set = imageMNIST(csv_path=f'{root}/{cid}_image.csv',
                               transform=T.Compose(transform_steps))
        return [
            DataLoader(audio_set, **loader_kwargs),
            DataLoader(image_set, **loader_kwargs),
        ]

    probe_dls = _loader_pair(probe_path, probe_kwargs)
    test_dls = _loader_pair(test_path, test_kwargs)
    return probe_dls, test_dls

def fused_load_dls(cid, probe_path, test_path, dl_config=None):
    """Create one probe loader and one test loader matching a client's state.

    The dataset class is picked from ``CLIENT_STATES[cid]``: ``audioMNIST`` for
    ``STATE.AUDIO``, ``imageMNIST`` for ``STATE.IMAGE``, and the labelled
    ``mmMNIST`` for every other state.

    Args:
        cid: Client id; selects both the state and the CSV file names.
        probe_path: Directory containing the probing CSV files.
        test_path: Directory containing the test CSV files.
        dl_config: Keyword arguments for the probe ``DataLoader``; defaults to
            batch size 64 with shuffling. The test loader always uses batch
            size 64 without shuffling.

    Returns:
        ``(probe_dl, test_dl)``.
    """
    probe_kwargs = {'batch_size': 64, 'shuffle': True} if dl_config is None else dl_config
    modality = CLIENT_STATES[cid]

    if modality == STATE.AUDIO:
        probe_set = audioMNIST(csv_path=f'{probe_path}/{cid}_audio.csv')
        test_set = audioMNIST(csv_path=f'{test_path}/{cid}_audio.csv')
    else:
        # Image-only and multimodal clients share one image transform between
        # their probe and test sets.
        _, transform_steps = get_mnist_transform()
        image_tf = T.Compose(transform_steps)
        if modality == STATE.IMAGE:
            probe_set = imageMNIST(csv_path=f'{probe_path}/{cid}_image.csv',
                                   transform=image_tf)
            test_set = imageMNIST(csv_path=f'{test_path}/{cid}_image.csv',
                                  transform=image_tf)
        else:
            probe_set = mmMNIST(csv_path=f'{probe_path}/{cid}.csv',
                                image_transform=image_tf,
                                with_label=True)
            test_set = mmMNIST(csv_path=f'{test_path}/{cid}.csv',
                               image_transform=image_tf,
                               with_label=True)

    probe_dl = DataLoader(probe_set, **probe_kwargs)
    test_dl = DataLoader(test_set, batch_size=64, shuffle=False)
    return probe_dl, test_dl


def pretty_print(cid, audio_acc, image_acc):
    """Print a client's state together with its audio and image accuracies.

    The parameters were previously named ``(cid, image_acc, audio_acc)`` while
    the second argument was printed under "Audio Accuracy" and the third under
    "Image Accuracy". The names now match the labels. Positional calls print
    exactly what they printed before; keyword calls are no longer mislabelled.

    Args:
        cid: Client id; also used to look up the state in ``CLIENT_STATES``.
        audio_acc: Accuracy shown on the "Audio Accuracy" line.
        image_acc: Accuracy shown on the "Image Accuracy" line.
    """
    state = CLIENT_STATES[cid]
    display = ("[CLIENT {}] -> {}  \n"
               "Audio Accuracy: {:.2f}\n"
               "Image Accuracy: {:.2f}"
              ).format(cid, state, audio_acc, image_acc)
    print(display)
    

    

In [3]:
from utils.train import val, linear_prob
import torch.nn as nn
from experiments.mmvae.mnist.model import _make_mlp

class StandardArch(nn.Module):
    """A backbone followed by a head, with optional L2-normalised features.

    Args:
        backbone: Module producing the feature tensor.
        head: Module applied to the (optionally normalised) features.
        normalize: When true, features are normalised along the last dimension
            before being passed to ``head``.
    """

    def __init__(self, backbone, head, normalize=False):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self._norm = normalize

    def forward(self, x):
        """Return ``head(backbone(x))``, normalising the features if configured."""
        embedding = self.backbone(x)
        embedding = F.normalize(embedding, dim=-1) if self._norm else embedding
        return self.head(embedding)
    
def load_rec_net(path, key):
    """Build the reconstruction MLP and load its weights from a checkpoint.

    The network is Linear(64, 128) -> LayerNorm(128) -> ReLU -> Linear(128, 64).

    Args:
        path: Path to the checkpoint file.
        key: Checkpoint entry holding the network's state dict (the callers in
            this notebook pass 'a2i' or 'i2a').

    Returns:
        The network on ``device`` with the checkpoint weights loaded.
    """
    # net = _make_mlp(128, 256, 128, use_bn=False).to(device)
    net = nn.Sequential(
        nn.Linear(64, 128),
        nn.LayerNorm(128),
        nn.ReLU(),
        nn.Linear(128, 64)
    ).to(device)
    # map_location remaps tensors that were saved on a GPU, so the checkpoint
    # still loads when `device` has fallen back to 'cpu'.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckp = torch.load(path, map_location=device)
    net.load_state_dict(ckp[key])
    return net
    
def fused_fwd(models, inputs,
              put_first=False,
              rec_net=None):
    """Compute a fused embedding by concatenating two embeddings on the last dim.

    Two calling modes:

    * ``rec_net is None``: ``models`` and ``inputs`` are parallel sequences;
      each model is applied to its input and the results are concatenated in
      that order.
    * ``rec_net`` given: ``models`` is a single encoder and ``inputs`` a single
      tensor. The encoder output ("possessed") is passed through ``rec_net`` to
      produce the "generated" embedding of the missing modality.

    Args:
        models: Sequence of encoders, or one encoder when ``rec_net`` is given.
        inputs: Sequence of input tensors, or one tensor when ``rec_net`` is given.
        put_first: Only used with ``rec_net``. If true the result is
            ``[possessed, generated]``, otherwise ``[generated, possessed]``.
        rec_net: Optional module mapping the possessed embedding to the
            generated one.

    Returns:
        The concatenated embedding, detached from the autograd graph.
    """
    # The result is always returned detached, so run the forward passes without
    # recording an autograd graph that would be thrown away immediately.
    with torch.no_grad():
        if rec_net is None:
            embeds = [model(x) for model, x in zip(models, inputs)]
        else:
            # do reconstruction
            possessed = models(inputs)
            generated = rec_net(possessed)
            embeds = [possessed, generated] if put_first else [generated, possessed]
        fused = torch.cat(embeds, dim=-1)
    return fused.detach()


# Optimizer settings for the linear probe.
# NOTE(review): nothing in the visible notebook reads this dict; the Adam
# optimizers in fused_run() and run() are built with the same values
# (lr=1e-1, weight_decay=1e-5) written out inline.
probe_optim_config = {
    'lr' : 1e-1,
    'weight_decay' : 1e-5
}

# fusion version
def fused_run(recon: bool=False) -> List[Any]:
    """Probe and evaluate one classification head per client on fused features.

    For every client id in ``range(NUM_CLIENT)`` the backbones are reloaded from
    ``BACKBONE_CKP``, a head is fitted with ``linear_prob`` on the client's probe
    loader, and the result of ``val`` on the client's test loader is collected.

    Feature layout per client state:

    * ``STATE.BOTH``: audio and image embeddings are concatenated (128-d input
      to the head).
    * single modality, ``recon=False``: the client's own backbone feeds a 64-d
      head directly.
    * single modality, ``recon=True``: the missing modality's embedding is
      generated by a reconstruction net ('a2i' for audio clients, 'i2a' for
      image clients) and concatenated so that the audio part comes first
      (128-d input to the head).

    Args:
        recon: Whether single-modality clients reconstruct the missing
            modality's embedding before probing.

    Returns:
        One entry per client, in client-id order, holding whatever ``val``
        returns (the notebook later calls ``.cpu().numpy()`` on each entry, so
        presumably a tensor - confirm against ``utils.train.val``).
    """
    res: List[Any] = []
    # loop over clients
    for cid in range(NUM_CLIENT):
        # Backbones are rebuilt and reloaded from the checkpoint for every client.
        audio_backbone, image_backbone = load_models(BACKBONE_CKP)

        audio_backbone.eval()
        image_backbone.eval()

        backbones = [audio_backbone, image_backbone]

        # prepare probing loader and test_loader
        probe_dl, test_dl = fused_load_dls(cid, PROBE_PATH, TEST_PATH)

        # If both modalities are present, use the concatenated features as the
        # fused feature; otherwise use the single-modality feature only.
        if CLIENT_STATES[cid] == STATE.BOTH:
            head = nn.Linear(128, 10).to(device)
            # Alternative MLP head (disabled):
            # head = nn.Sequential(
            #     nn.Linear(128, 512),
            #     nn.ReLU(),
            #     nn.Linear(512, 10)
            # ).to(device)
            optimizer = optim.Adam(head.parameters(), lr=1e-1, weight_decay=1e-5)
            criterion = nn.CrossEntropyLoss()

            # Loss callback handed to linear_prob: batches are
            # (audio, image, label); `model` is the [audio, image] backbone list.
            def unpack_and_forward(model, inputs):
                inputs = [t.to(device) for t in inputs]
                audio_x, image_x, y = inputs
                fused_embed = fused_fwd(model, [audio_x, image_x])
                pred = head(fused_embed)
                loss = criterion(pred, y)
                return loss
            extractor = backbones
        else:
            state = CLIENT_STATES[cid]
            # Pick the backbone of the modality the client owns and, when
            # reconstructing, the net that generates the other modality.
            if state == STATE.AUDIO:
                extractor = audio_backbone
                if recon:
                    constructor = load_rec_net(BACKBONE_CKP, 'a2i')
            else:
                extractor = image_backbone
                if recon:
                    constructor = load_rec_net(BACKBONE_CKP, 'i2a')
            if recon:
                extractor = [extractor, constructor]
                # Alternative MLP head (disabled):
                # head = nn.Sequential(
                #         nn.Linear(128, 512),
                #         nn.ReLU(),
                #         nn.Linear(512, 10)
                #     ).to(device)
                head = nn.Linear(128, 10).to(device)
                optimizer = optim.Adam(head.parameters(), lr=1e-1, weight_decay=1e-5)
            else:
                # Alternative MLP head (disabled):
                # head = nn.Sequential(
                #         nn.Linear(64, 512),
                #         nn.ReLU(),
                #         nn.Linear(512, 10)
                #     ).to(device)
                head = nn.Linear(64, 10).to(device)
                optimizer = optim.Adam(head.parameters(), lr=1e-1, weight_decay=1e-5)
            criterion = nn.CrossEntropyLoss()

            if recon:
                # Loss callback handed to linear_prob: batches are (x, label);
                # `model` is the [encoder, reconstruction net] pair.
                def rec_then_forward(model, inputs):
                    inputs = [t.to(device) for t in inputs]
                    encoder, constructor = model
                    x, y = inputs
                    # Audio clients put their own embedding first so the fused
                    # vector is ordered [audio, image] like for STATE.BOTH.
                    put_first = (state == STATE.AUDIO)
                    fused_embed = fused_fwd(encoder, x, put_first, constructor)
                    pred = head(fused_embed)
                    loss = criterion(pred, y)
                    return loss
                unpack_and_forward = rec_then_forward
            else:
                # No custom callback: linear_prob receives None here.
                unpack_and_forward = None

        # probe the head
        # NOTE(review): 35 is passed positionally; presumably the number of
        # probing epochs - confirm against utils.train.linear_prob.
        linear_prob(
            extractor, head,
            probe_dl,
            optimizer,
            criterion,
            device,
            35,
            use_tqdm=False,
            normalize=False,
            unpack_and_forward=unpack_and_forward
        )

        # eval on the test set
        if CLIENT_STATES[cid] == STATE.BOTH:
            # `test_model` is a plain list [audio backbone, image backbone, head].
            test_model = [*extractor, head]
            def _test_unpack_and_fwd(model, inputs):
                head = model[-1]
                inputs = [t.to(device) for t in inputs]
                # The last batch element is the label and is not forwarded.
                fused_feature = fused_fwd(model[:-1], inputs[:-1])
                pred = head(fused_feature)
                return pred
        else:
            if recon:
                # `test_model` is a plain list [encoder, reconstruction net, head].
                test_model = [*extractor, head]
                def _test_unpack_and_fwd(model, inputs):
                    encoder, constructor, head = model
                    inputs = [t.to(device) for t in inputs]
                    x, y = inputs
                    put_first = (CLIENT_STATES[cid] == STATE.AUDIO)
                    fused_feature = fused_fwd(encoder, x, put_first, constructor)
                    pred = head(fused_feature)
                    return pred
            else:
                test_model = nn.Sequential(
                    extractor, head
                ).to(device)
                _test_unpack_and_fwd = None

        acc = val(test_model, test_dl, device, _test_unpack_and_fwd)


        res.append(acc)
    return res

In [4]:
from utils.train import val, linear_prob
import torch.nn as nn

# NOTE(review): this class is defined identically in the previous cell; the
# definition here replaces that one when the cells run in order.
class StandardArch(nn.Module):
    """Wrap a backbone and a head into a single module.

    Args:
        backbone: Feature extractor applied to the input.
        head: Module applied to the backbone features.
        normalize: If true, features are L2-normalised along the last dimension
            before reaching ``head``.
    """

    def __init__(self, backbone, head, normalize=False):
        super().__init__()
        self.backbone = backbone
        self.head = head
        self._norm = normalize

    def forward(self, x):
        """Run the backbone, optionally normalise, then run the head."""
        feats = self.backbone(x)
        if not self._norm:
            return self.head(feats)
        return self.head(F.normalize(feats, dim=-1))
        


def run():
    """Probe and evaluate a separate linear head per modality for every client.

    For each client id in ``range(NUM_CLIENT)`` the audio and image backbones
    are reloaded from ``BACKBONE_CKP``; a fresh ``Linear(64, 10)`` head is then
    fitted on each modality's probe loader with ``linear_prob`` and scored on
    the matching test loader with ``val``.

    Returns:
        One list per client, in client-id order, holding the ``val`` result for
        the audio backbone followed by the one for the image backbone.
    """
    res: List[List[Any]] = []
    for cid in range(NUM_CLIENT):
        audio_backbone, image_backbone = load_models(BACKBONE_CKP)

        # Loaders come back as [audio, image], matching the backbone order.
        probe_dls, test_dls = load_dls(cid, PROBE_PATH, TEST_PATH)

        client_res = []
        per_modality = zip((audio_backbone, image_backbone), probe_dls, test_dls)
        for extractor, probe_dl, test_dl in per_modality:
            head = nn.Linear(64, 10).to(device)
            optimizer = optim.Adam(head.parameters(), lr=1e-1, weight_decay=1e-5)
            criterion = nn.CrossEntropyLoss()

            # Fit the head on the probing split.
            linear_prob(
                extractor, head,
                probe_dl,
                optimizer,
                criterion,
                device,
                35,
                use_tqdm=False,
                normalize=False
            )

            # Put the backbone back into eval mode before scoring.
            extractor.eval()

            probed_model = StandardArch(extractor, head, normalize=False).to(device)
            client_res.append(val(probed_model, test_dl, device))

        res.append(client_res)
    return res

In [5]:
def duplicate_runs(fn: Callable[[Any], List[float]],
                   times: int,
                   aggregator: Callable=None,
                   seeds: List[int]=None):
    """Call ``fn`` repeatedly and aggregate its results position by position.

    Args:
        fn: Zero-argument callable returning one sequence of results per call.
        times: Number of times ``fn`` is called.
        aggregator: Callable receiving the list of values collected at one
            position across all runs; defaults to the arithmetic mean.
        seeds: Optional seeds; when given, ``seed_everything(seeds[idx])`` is
            called before run ``idx``.

    Returns:
        A list with one aggregated value per result position. Empty when
        ``times`` is 0 (previously this case raised because the transposed
        results were never assigned).

    Raises:
        ValueError: If ``seeds`` is given but holds fewer than ``times`` values.
    """
    # Fail before any run starts instead of part-way through the repetitions.
    if seeds is not None and len(seeds) < times:
        raise ValueError(
            'seeds has {} entries but {} runs were requested'.format(len(seeds), times)
        )

    # default aggregator is the mean function
    if aggregator is None:
        def aggregator(res: List):
            return sum(res) / len(res)

    collected_res = []
    for idx in range(times):
        if seeds is not None:
            seed_everything(seeds[idx])
        collected_res.append(fn())

    # Transpose once, after all runs: one list of per-run values per position.
    return [aggregator(list(per_position)) for per_position in zip(*collected_res)]

def agg(t: List):
    """Return the position-wise mean over a list of equally ordered result lists.

    Args:
        t: List of result sequences, one per run.

    Returns:
        A list whose k-th element is the mean of the k-th elements of all runs.
    """
    return [sum(column) / len(column) for column in map(list, zip(*t))]

# Disabled alternative driver:
# res = duplicate_runs(fused_run, 5, agg)
import numpy as np
# Repeat the fused probing experiment three times and average per client.
# The RNGs are seeded once in the first cell and not reseeded here, so the
# repetitions draw different random numbers.
multi_run = []
for _ in range(3):
    # Move each per-client result to the CPU and convert it to NumPy.
    one_time_res = [i.cpu().numpy() for i in fused_run(recon=False)]
    multi_run.append(one_time_res)
# Axis 0 indexes the repetition, so the mean leaves one value per client.
res = np.array(multi_run).mean(axis=0)

In [6]:
# Per-modality variant, for results shaped like the output of run()
# (one [audio, image] pair per client):
# for cid, (audio_acc, image_acc) in enumerate(res):
#     pretty_print(cid, audio_acc, image_acc)

# One averaged accuracy per client, labelled with the client's modality state.
for cid, acc in enumerate(res):
    print('Client {}[{}]: {:.2f}'.format(cid, CLIENT_STATES[cid], acc))

Client 0[STATE.AUDIO]: 0.98
Client 1[STATE.BOTH]: 0.99
Client 2[STATE.BOTH]: 1.00
Client 3[STATE.AUDIO]: 0.99
Client 4[STATE.IMAGE]: 0.97
Client 5[STATE.IMAGE]: 0.98
Client 6[STATE.AUDIO]: 1.00
Client 7[STATE.IMAGE]: 0.97
Client 8[STATE.BOTH]: 0.99
Client 9[STATE.BOTH]: 0.99


In [7]:
# import multiprocessing as mp
# exec_fun = run
# def wrapped_run(res_collector):
#     res_collector.append(exec_fun())
    
# num_worker = 5
# manager = mp.Manager()
# res_container = manager.list()

# if __name__ == '__main__':
#     # mp.set_start_method('spawn')

#     workers = []
#     for _ in range(num_worker):
#         p = mp.Process(target=wrapped_run, args=(res_container, ))
#         p.start()
#         workers.append(p)
#     for p in workers:
#         p.join()

# def agg(t: List):
#     res = []
#     transposed_t = [list(x) for x in zip(*t)]
#     for grouped_res in transposed_t:
#         res.append(sum(grouped_res) / len(grouped_res))
#     return res

# agg_res = agg(res_container)

In [8]:
# for cid, (audio_acc, image_acc) in enumerate(res):
#     pretty_print(cid, audio_acc, image_acc)