In [1]:
import torch
import torch.nn as nn
import os
import pandas as pd
import pickle

In [2]:
# Load the experiment configuration from a local pickle. It is consumed further
# down via `config.items()`, each entry unpacking as
# (source, target) -> (common, tgt_private).
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open('small_datasets.pkl', 'rb') as pkl_file:
    config = pickle.load(pkl_file)

In [3]:
# Pin CUDA_VISIBLE_DEVICES to '1' for this kernel. This has to be in place before
# the first CUDA call for it to take effect.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
from modules.loaders.balanced_sampling import create_train_dataloader
from modules.loaders.resnet_features.features_dataset import FeaturesDataset


def train_NN_N_plus_One(feats, lbls, num_epochs, params, balanced, lr):
    """Train a small MLP classifier with one output more than the source classes.

    The network is Linear(pca_dim -> proj_dim) -> ReLU -> Linear(proj_dim -> N + 1),
    where N = params.num_common + params.num_src_priv. It is optimised with Adam
    and cross-entropy on CUDA, so a GPU is required.

    Args:
        feats: feature matrix, indexable by a boolean mask built from `lbls`.
        lbls: integer labels aligned with `feats`. Samples whose label is negative
            or greater than N are dropped before training (callers in this
            notebook use -1 for target samples without a pseudo-label and N for
            rejected samples).
        num_epochs: number of full passes over the data loader.
        params: object exposing `num_common`, `num_src_priv`, `pca_dim`, `proj_dim`.
        balanced: forwarded unchanged to `create_train_dataloader`.
        lr: Adam learning rate.

    Returns:
        The trained `nn.Sequential` model, switched to eval mode (still on CUDA).
    """
    num_outputs = params.num_common + params.num_src_priv + 1

    # Keep only samples carrying a usable label in [0, num_outputs).
    keep = (lbls >= 0) & (lbls < num_outputs)
    feats, lbls = feats[keep], lbls[keep]

    layers = [
        nn.Linear(params.pca_dim, params.proj_dim),
        nn.ReLU(),
        nn.Linear(params.proj_dim, num_outputs),
    ]
    model = nn.Sequential(*layers).cuda().train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss().cuda()

    # Fixed mini-batch size of 32.
    loader = create_train_dataloader(FeaturesDataset(feats, lbls), 32, balanced)

    for _ in range(num_epochs):
        for batch_feats, batch_lbls in loader:
            optimizer.zero_grad()
            loss = criterion(model(batch_feats.cuda()), batch_lbls.cuda())
            loss.backward()
            optimizer.step()

    return model.eval()

In [None]:
from modules.logging.format_utils import format_measures
from modules.algorithms.nn.OSLPP_NN_UTILS import train_NN, predict_NN
from modules.loaders.osda import create_datasets_sub
from modules.algorithms.base.OSLPP import Params, do_l2_normalization, do_pca, select_closed_set_pseudo_labels, \
    evaluate_T
from modules.loaders.common import set_seed
from modules.selection.uncertanties import SelectRejectMode, select_initial_rejected

# Result rows are collected in a plain list and the DataFrame is rebuilt from it
# after every run. `DataFrame.append` was removed in pandas 2.0, and rebuilding
# keeps `results` usable by later cells even if the sweep is interrupted.
# 'n_r' is declared up front because every recorded row carries it.
RESULT_COLUMNS = ['source', 'target', 'desc', 'lr', 'seed', 'epochs', 'n_r']
result_rows = []
results = pd.DataFrame(columns=RESULT_COLUMNS)

select_reject_mode = SelectRejectMode.CONFIDENCE

# Sweep: every (source, target) pair x epochs x learning rate x n_r x seed.
for (source, target), (common, tgt_private) in config.items():
    for epochs in [10, 50]:
        for lr in [1e-3, 1e-4]:
            # NOTE(review): `n_r` is only written to the results table. `params.n_r`
            # is fixed to 1200 below and the ratio-based override is commented out,
            # so the three n_r settings currently repeat the same experiment.
            for n_r in [0.1, 0.25, 0.35]:
                for seed in range(3):
                    set_seed(seed)
                    print(source, '->', target, 'lr=', lr, 'seed=', seed)
                    params = Params(pca_dim=512, proj_dim=128, T=10, n_r=1200, n_r_ratio=None,
                                    dataset='DomainNet_DCC', source=source, target=target,
                                    num_common=len(common), num_src_priv=0, num_tgt_priv=len(tgt_private))
                    (feats_S, labels_S), (feats_T, labels_T) = create_datasets_sub(params.dataset,
                                                                                   params.source,
                                                                                   params.target,
                                                                                   common,
                                                                                   tgt_private)
                    # params.n_r = int(len(labels_T) * n_r)
                    num_src_classes = params.num_common + params.num_src_priv

                    # l2 normalization and pca
                    feats_S, feats_T = do_l2_normalization(feats_S, feats_T)
                    feats_S, feats_T = do_pca(feats_S, feats_T, params.pca_dim)
                    feats_S, feats_T = do_l2_normalization(feats_S, feats_T)

                    # Convert to tensors; feats_all stacks source rows first, then target rows.
                    feats_S, feats_T = torch.tensor(feats_S), torch.tensor(feats_T)
                    labels_S, labels_T = torch.tensor(labels_S), torch.tensor(labels_T)
                    feats_all = torch.cat((feats_S, feats_T), dim=0)

                    # ---- Iteration t = 1: classifier trained on source data only ----
                    t = 1

                    model = train_NN(feats_S, labels_S, epochs, params, balanced=True, lr=lr, initial=True)
                    feats_S_2, predictions_S_2 = predict_NN(model, feats_S, labels_S)
                    feats_T_2, predictions_T_2 = predict_NN(model, feats_T, labels_T)

                    # predict_NN must hand the features back in their original order.
                    assert (feats_S_2 == feats_S).all() and (feats_T_2 == feats_T).all()

                    confs, cs_pseudo_labels = predictions_T_2.max(dim=1)
                    # NOTE(review): here the positional order is (labels, confs, predictions, t, T),
                    # whereas the call inside the refinement loop below passes
                    # (labels, confs, t, T, predictions, mode). One of the two is presumably
                    # wrong; confirm against the signature of select_closed_set_pseudo_labels.
                    selected = torch.tensor(select_closed_set_pseudo_labels(cs_pseudo_labels.numpy(), confs.numpy(), predictions_T_2,
                                                                            t, params.T,
                                                                            mode=select_reject_mode))
                    rejected = torch.tensor(select_initial_rejected(confs, predictions_T_2, params.n_r, mode=select_reject_mode))
                    # A rejected sample is never also counted as selected.
                    selected = selected * (1 - rejected)

                    # Pseudo-label encoding: rejected -> extra class index num_src_classes,
                    # neither rejected nor selected -> -1 (filtered out by train_NN_N_plus_One).
                    pseudo_labels = cs_pseudo_labels.clone()
                    pseudo_labels[rejected == 1] = num_src_classes
                    pseudo_labels[(rejected == 0) * (selected == 0)] = -1

                    metrics = evaluate_T(params.num_common, params.num_src_priv, params.num_tgt_priv,
                                         labels_T.numpy(), cs_pseudo_labels.numpy(), rejected.numpy())
                    # Same metrics restricted to target samples that were selected or rejected.
                    where = torch.where((selected == 1) + (rejected == 1))[0]
                    metrics_selected = evaluate_T(params.num_common, params.num_src_priv, params.num_tgt_priv,
                                                  labels_T[where].numpy(), cs_pseudo_labels[where].numpy(),
                                                  rejected[where].numpy())
                    print('______')
                    print(f'Iteration t={t}')
                    print('all: ', format_measures(metrics))
                    print('selected: ', format_measures(metrics_selected))

                    # ---- Iterations t = 2 .. T-1: retrain an (N+1)-way classifier on the
                    # source labels plus the current target pseudo-labels ----
                    for t in range(2, params.T):
                        model = train_NN_N_plus_One(feats_all, torch.cat((labels_S, pseudo_labels), axis=0), epochs, params, balanced=True, lr=lr)
                        feats_S_2, preds_S_2 = predict_NN(model, feats_S, labels_S)
                        feats_T_2, preds_T_2 = predict_NN(model, feats_T, labels_T)
                        assert (feats_S_2 == feats_S).all() and (feats_T_2 == feats_T).all()

                        confs, cs_pseudo_labels = preds_T_2.max(dim=1)
                        # NOTE(review): positional order differs from the t = 1 call above.
                        selected = torch.tensor(select_closed_set_pseudo_labels(cs_pseudo_labels.numpy(), confs.numpy(), t, params.T, preds_T_2, select_reject_mode))
                        pseudo_labels = cs_pseudo_labels.clone()
                        pseudo_labels[selected == 0] = -1
                        print('Selected CNT:', len(selected[selected != 0]), '/', len(selected))

                        # Unselected samples were just set to -1, so only *selected* samples
                        # predicted as the extra class are flagged as rejected here.
                        rejected = torch.zeros_like(pseudo_labels)
                        rejected[pseudo_labels == num_src_classes] = 1

                        metrics = evaluate_T(params.num_common, params.num_src_priv, params.num_tgt_priv, labels_T.numpy(), cs_pseudo_labels.numpy(), rejected.numpy())
                        where = torch.where((selected == 1) + (rejected == 1))[0]
                        metrics_selected = evaluate_T(params.num_common, params.num_src_priv, params.num_tgt_priv, labels_T[where].numpy(), cs_pseudo_labels[where].numpy(), rejected[where].numpy())
                        print('______')
                        print(f'Iteration t={t}')
                        print('all: ', format_measures(metrics))
                        print('selected: ', format_measures(metrics_selected))

                    # Number of target samples still without a pseudo-label after the last round.
                    print((pseudo_labels == -1).sum())

                    # Final evaluation over all target samples with the last round's predictions.
                    _rejected = rejected
                    metrics = evaluate_T(params.num_common, params.num_src_priv, params.num_tgt_priv, labels_T.numpy(), cs_pseudo_labels.numpy(), _rejected.numpy())
                    print('all: ', format_measures(metrics))

                    result_rows.append({'source': source, 'target': target, 'desc': format_measures(metrics), 'lr': lr, 'seed': seed, 'n_r': n_r, 'epochs': epochs})
                    results = pd.DataFrame(result_rows, columns=RESULT_COLUMNS)

In [13]:
# Write the collected sweep results to CSV (header row, no index column).
results_csv_path = '../all_results/results/dcc__conf__small__NN_N+1__nn_raw.csv'
results.to_csv(results_csv_path, index=False, header=True)