In [None]:
% load_ext autoreload
% autoreload 2

In [None]:

import os
import pickle

from modules.logging.format_utils import format_measures
from modules.collecting.results_collector import DataFrameCollector
from modules.logging.logger import DefaultLogger
from modules.algorithms.base.OSLPP import Params
from modules.selection.uncertanties import SelectRejectMode
import torch.nn.functional as F

In [None]:
# Load the experiment configuration. The sweep cell iterates it as
# {(source, target): (common, tgt_private)}.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load configuration files that come from a trusted location.
with open('experiments/configs/small_datasets.pkl', 'rb') as f:
    config = pickle.load(f)

In [None]:
# Restrict CUDA to the device with index 1.
# NOTE(review): presumably this has to run before the first CUDA call made in
# this kernel in order to take effect — confirm when reordering cells.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

In [None]:
from dataclasses import dataclass


@dataclass
class MlInfoParams:
    """Hyper-parameter bundle declaring the same fields as ``MlInfonceParams``.

    NOTE(review): no other visible cell references this class; it looks like a
    duplicate of ``MlInfonceParams`` — confirm before relying on or removing it.
    The per-field notes below are presumed to match ``MlInfonceParams``.
    """
    oslpp_params: Params  # base OSLPP settings (dimensions, class counts, dataset names)
    tau: float  # presumably the InfoNCE temperature — TODO confirm
    lr: float  # optimizer learning rate
    batch_size: int  # training batch size
    num_layers: int  # number of linear layers in the projection network
    pairwise_distance_fn: str  # name of the distance function (e.g. 'eucl' or 'cos'), not the callable itself
    pos: int  # presumably positives sampled per anchor — TODO confirm
    neg: int  # presumably negatives sampled per anchor — TODO confirm
    epochs: int  # number of training epochs

@dataclass
class MlInfonceParams:
    """Hyper-parameters for the InfoNCE metric-learning variant of OSLPP.

    Consumed by ``create_model``, ``train_model`` and ``train`` in this notebook.
    """
    # Base OSLPP settings; this notebook reads pca_dim, proj_dim, T, n_r,
    # num_common, num_src_priv, num_tgt_priv, dataset, source and target from it.
    oslpp_params: Params
    tau: float  # passed to InfoNCELoss; presumably the softmax temperature — TODO confirm
    lr: float  # learning rate of the Adam optimizer in train_model
    batch_size: int  # batch size handed to create_train_dataloader
    num_layers: int  # number of nn.Linear layers built by create_model
    pairwise_distance_fn: str  # key resolved by get_distance_fn_pairwise: 'eucl' or 'cos' (a name, not a callable)
    pos: int  # forwarded to InfoNCEDataset; presumably positives per anchor — TODO confirm
    neg: int  # forwarded to InfoNCEDataset; presumably negatives per anchor — TODO confirm
    epochs: int  # number of passes over the dataloader in train_model

In [None]:
from modules.algorithms.base.OSLPP import get_l2_normalized
from modules.algorithms.nn.OSLPP_NN_UTILS import predict_NN
import math


def create_model(params: MlInfonceParams):
    """Build the projection network from PCA features to the embedding space.

    With ``num_layers == 1`` the network is a single linear map. Otherwise it
    is a stack of ``num_layers`` linear layers separated by ReLU, whose hidden
    width is the integer part of the geometric mean of the input and output
    widths.

    Args:
        params: settings providing ``num_layers`` and, through
            ``oslpp_params``, ``pca_dim`` (input width) and ``proj_dim``
            (output width).

    Returns:
        ``nn.Linear`` for a single layer, ``nn.Sequential`` otherwise.

    Raises:
        ValueError: if ``params.num_layers`` is smaller than 1. Such values
            used to fall through silently to a two-layer network.
    """
    in_dim = params.oslpp_params.pca_dim
    out_dim = params.oslpp_params.proj_dim

    if params.num_layers < 1:
        raise ValueError(f'num_layers must be >= 1, got {params.num_layers}')
    if params.num_layers == 1:
        return nn.Linear(in_dim, out_dim)

    mid_dim = int(math.sqrt(in_dim * out_dim))
    layers = [nn.Linear(in_dim, mid_dim)]
    # The first and last linear layers are added explicitly, which leaves
    # num_layers - 2 hidden-to-hidden layers in between.
    for _ in range(params.num_layers - 2):
        layers += [nn.ReLU(), nn.Linear(mid_dim, mid_dim)]
    layers += [nn.ReLU(), nn.Linear(mid_dim, out_dim)]
    return nn.Sequential(*layers)


def predict(model, features, labels):
    """Run ``model`` in eval mode over ``features`` and L2-normalize the result.

    The forward pass itself is delegated to ``predict_NN``; ``labels`` is
    handed through to it unchanged.
    """
    eval_model = model.eval()
    return get_l2_normalized(predict_NN(eval_model, features, labels))

In [None]:
from modules.loaders.balanced_sampling import create_train_dataloader
from modules.metric_learning.ml import TripletLoss, get_euclidian_distances, \
    get_cosine_distances, TripletDataset, get_pairwise_euclidian_distances, get_pairwise_cosine_distances, InfoNCELoss, \
    InfoNCEDataset


def get_distance_fn(name):
    """Look up a (non-pairwise) distance function by its short name.

    Args:
        name: ``'eucl'`` for ``get_euclidian_distances`` or ``'cos'`` for
            ``get_cosine_distances``.

    Returns:
        The matching distance function.

    Raises:
        ValueError: if ``name`` is not one of the supported names. ValueError
            is a subclass of Exception, so existing ``except Exception``
            handlers keep working.
    """
    if name == 'eucl':
        return get_euclidian_distances
    if name == 'cos':
        return get_cosine_distances
    raise ValueError(f'Unsupported distance function name {name}')

def get_distance_fn_pairwise(name):
    """Look up a pairwise distance function by its short name.

    Args:
        name: ``'eucl'`` for ``get_pairwise_euclidian_distances`` or ``'cos'``
            for ``get_pairwise_cosine_distances``.

    Returns:
        The matching pairwise distance function.

    Raises:
        ValueError: if ``name`` is not one of the supported names. ValueError
            is a subclass of Exception, so existing ``except Exception``
            handlers keep working.
    """
    if name == 'eucl':
        return get_pairwise_euclidian_distances
    if name == 'cos':
        return get_pairwise_cosine_distances
    raise ValueError(f'Unsupported pairwise distance function name {name}')

def train_model(features, labels, params: MlInfonceParams):
    """Train a fresh projection network with the InfoNCE loss on the GPU.

    Only samples with a negative label are dropped before training. Every
    label >= 0 is kept, including the extra label ``num_common + num_src_priv``
    that ``train`` assigns to rejected target samples, so those samples take
    part in training as an additional class.

    Args:
        features: feature tensor, filtered along dim 0 with ``labels >= 0``.
        labels: label tensor aligned with ``features``.
        params: supplies the model shape (via ``create_model``), ``lr``,
            ``tau``, ``pos``, ``neg``, ``batch_size`` and ``epochs``.

    Returns:
        The trained model, switched to eval mode.
    """
    keep = labels >= 0
    features, labels = features[keep], labels[keep]

    model = create_model(params).cuda().train()
    optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
    loss_fn = InfoNCELoss(params.tau)

    ds = InfoNCEDataset(features, labels, params.pos, params.neg)
    dl = create_train_dataloader(ds, batch_size=params.batch_size, balanced=True)
    for _ in range(params.epochs):
        for batch in dl:
            anchor, pos, neg = batch['anchor'][0], batch['pos'][0], batch['neg'][0]
            b_pos, num_pos, d_pos = pos.shape
            b_neg, num_neg, d_neg = neg.shape
            # Positives and negatives are 3-D: flatten the first two dims for
            # the forward pass, then restore them around the embedding dim.
            anchor, pos, neg = model(anchor.cuda()), \
                model(pos.cuda().view(-1, d_pos)).view(b_pos, num_pos, -1), \
                model(neg.cuda().view(-1, d_neg)).view(b_neg, num_neg, -1)
            loss = loss_fn(anchor, pos, neg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model.eval()

In [None]:
from modules import metric_learning


def get_centroids(predictions, labels):
    """Return the L2-normalized centroids computed by ``metric_learning.ml``.

    ``predictions`` and ``labels`` are handed to
    ``metric_learning.ml.get_centroids`` unchanged.
    """
    ml = metric_learning.ml
    return ml.get_l2_normalized(ml.get_centroids(predictions, labels))

In [None]:
from modules.selection.uncertanties import select_initial_rejected, get_new_rejected
import torch
import numpy as np
from modules.algorithms.base.OSLPP import do_l2_normalization, do_pca, select_closed_set_pseudo_labels, evaluate_T
from modules.loaders.osda import create_datasets_sub
from modules.loaders.common import set_seed
from torch import nn


def train(params: MlInfonceParams,
          select_reject_mode, seed,
          common, tgt_private, logger):
    """Run the iterative pseudo-labelling procedure for one source/target pair.

    Source and target features are L2-normalized, reduced with PCA and
    L2-normalized again. Each iteration then:
      1. trains a new projection model on the source samples plus the target
         samples that currently carry a pseudo-label >= 0;
      2. computes per-class centroids from the source predictions;
      3. assigns every target sample the class with the highest softmax score
         over the negated distances to the centroids;
      4. updates which target samples are selected and which are rejected.

    Args:
        params: metric-learning settings including the nested OSLPP settings.
            ``oslpp_params.n_r`` is read as a fraction of the target set and
            is not modified.
        select_reject_mode: forwarded as ``mode`` to the selection helpers.
        seed: value passed to ``set_seed``.
        common: forwarded to ``create_datasets_sub``.
        tgt_private: forwarded to ``create_datasets_sub``.
        logger: object providing ``log(t, metrics, metrics_selected)`` and
            ``log_res(metrics)``.

    Returns:
        The ``evaluate_T`` result for all target samples after the last
        iteration.
    """
    set_seed(seed)
    oslpp = params.oslpp_params
    print(oslpp.source, '->', oslpp.target, 'lr=', params.lr, 'seed=', seed)

    (feats_S, labels_S), (feats_T, labels_T) = create_datasets_sub(oslpp.dataset,
                                                                   oslpp.source,
                                                                   oslpp.target,
                                                                   common,
                                                                   tgt_private, None)

    # Turn the rejection fraction into a sample count. It is kept in a local
    # so the caller's params are not overwritten: writing the count back made
    # a second call with the same params multiply by the target size again.
    n_r = int(len(labels_T) * oslpp.n_r)
    num_src_classes = oslpp.num_common + oslpp.num_src_priv

    feats_S, feats_T = do_l2_normalization(feats_S, feats_T)
    feats_S, feats_T = do_pca(feats_S, feats_T, oslpp.pca_dim)
    feats_S, feats_T = do_l2_normalization(feats_S, feats_T)

    # `np.int` was only an alias of the builtin `int` and was removed in
    # NumPy 1.24; the builtin gives the same dtype.
    rejected = np.zeros((len(feats_T),), dtype=int)
    pseudo_labels = -torch.tensor(np.ones((len(feats_T),), dtype=int))  # all -1: nothing labelled yet
    cs_pseudo_labels = pseudo_labels

    # initial
    feats_S, feats_T = torch.tensor(feats_S), torch.tensor(feats_T)
    labels_S, labels_T = torch.tensor(labels_S), torch.tensor(labels_T)
    feats_all = torch.cat((feats_S, feats_T), dim=0)

    pairwise_distance_fn = get_distance_fn_pairwise(params.pairwise_distance_fn)

    # NOTE(review): the loop covers t = 1 .. T-1, while the assert after the
    # loop requires that no target sample is left at -1 — confirm that the
    # upper bound is intended.
    for t in range(1, oslpp.T):
        model = train_model(feats_all, torch.cat([labels_S, pseudo_labels], dim=0), params)

        centroids = get_centroids(predict(model, feats_S, labels_S), labels_S)
        preds = predict(model, feats_T, labels_T)

        # Softmax over negated distances: a closer centroid gets a higher score.
        pairwise_distances = pairwise_distance_fn(preds, centroids)
        pairwise_distances = F.softmax(-pairwise_distances, dim=-1)

        # cs_pseudo_probs is not used below; the selection helpers receive preds.
        cs_pseudo_probs, cs_pseudo_labels = pairwise_distances.max(dim=-1)
        selected = torch.tensor(
            select_closed_set_pseudo_labels(cs_pseudo_labels.numpy(), preds.unsqueeze(0),
                                            t, oslpp.T,
                                            mode=select_reject_mode,
                                            uniform_ratio=None, balanced=False,
                                            weights=None, tops=None, aug_preds=None))

        if t == 1:
            rejected = torch.tensor(select_initial_rejected(preds.unsqueeze(0), n_r,
                                                            mode=select_reject_mode, weights=None,
                                                            tops=None, aug_preds=None))
        else:
            selected = selected * (1 - rejected)
            rejected_new = get_new_rejected(preds.unsqueeze(0), selected, rejected, mode=select_reject_mode,
                                            weights=None, tops=None, aug_preds=None)
            rejected[rejected_new == 1] = 1
        # A rejected sample never counts as selected.
        selected = selected * (1 - rejected)

        # Rejected samples get the extra label num_src_classes; samples that
        # are neither selected nor rejected stay unlabelled (-1).
        pseudo_labels = cs_pseudo_labels.clone()
        pseudo_labels[rejected == 1] = num_src_classes
        pseudo_labels[(rejected == 0) * (selected == 0)] = -1

        metrics = evaluate_T(oslpp.num_common, oslpp.num_src_priv,
                             oslpp.num_tgt_priv,
                             labels_T.numpy(), cs_pseudo_labels.numpy(), rejected.numpy())
        # Indices of the target samples that are either selected or rejected.
        where = torch.where((selected == 1) + (rejected == 1))[0]
        metrics_selected = evaluate_T(oslpp.num_common, oslpp.num_src_priv,
                                      oslpp.num_tgt_priv,
                                      labels_T[where].numpy(), cs_pseudo_labels[where].numpy(),
                                      rejected[where].numpy())
        logger.log(t, metrics, metrics_selected)

    assert (pseudo_labels == -1).sum() == 0

    metrics = evaluate_T(oslpp.num_common, oslpp.num_src_priv,
                         oslpp.num_tgt_priv, labels_T.numpy(),
                         cs_pseudo_labels.numpy(), rejected.numpy())
    logger.log_res(metrics)
    return metrics

In [None]:
# Sweep over every (source, target) pair in the config and over the
# hyper-parameter grids below, running one training per combination.
# NOTE(review): `results` is created here but nothing in this cell records
# into it; the metrics only reach the logger.
results = DataFrameCollector({'source': [], 'target': [], 'desc': [], 'lr': [], 'seed': [], 'epochs': []})
select_reject_mode = SelectRejectMode.CONFIDENCE
logger = DefaultLogger()
for (source, target), (common, tgt_private) in config.items():
    for epochs in [10]:
        for lr in [1e-3]:
            for n_r in [0.1, 0.15, 0.25]:
                for seed in range(1):
                    params = Params(pca_dim=512, proj_dim=128, T=10, n_r=n_r, n_r_ratio=None,
                                    dataset='DomainNet_DCC', source=source, target=target,
                                    num_common=len(common), num_src_priv=0, num_tgt_priv=len(tgt_private))
                    # lr and epochs come from the sweep variables; they used to
                    # be hard-coded here, so editing the grids had no effect.
                    ml_params = MlInfonceParams(oslpp_params=params,
                                                tau=0.1, lr=lr, batch_size=32, num_layers=2, epochs=epochs,
                                                pos=1, neg=30, pairwise_distance_fn='cos')
                    metrics = train(ml_params, select_reject_mode, seed, common, tgt_private, logger)