In [None]:
import torch
import torch.nn as nn
import os
import pandas as pd
import pickle
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset

In [None]:
# Experiment grid, iterated below as (source, target) -> (common, tgt_private).
# NOTE: pickle.load can execute arbitrary code -- only load trusted config files.
with open('experiments/configs/small_datasets.pkl', 'rb') as config_file:
    config = pickle.load(config_file)

In [None]:
# Restrict this process to GPU index 1. For this to take effect it has to be in
# place before CUDA is initialised in the process (i.e. before the first .cuda() call).
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [None]:
from modules.deep.DEEPNN import get_simple_deep_nn, fix_batch_normalization_layers


def get_deep(_, n_classes):
    """Build the network from ``get_simple_deep_nn`` with ``n_classes`` outputs, on the GPU.

    The first positional argument (the experiment params at the call site) is
    accepted for interface compatibility and ignored.

    Returns:
        The module in train mode, with ``fix_batch_normalization_layers``
        applied to every submodule, moved to CUDA.
    """
    net = get_simple_deep_nn(n_classes)
    net.train()
    # Module.apply visits every submodule recursively.
    net.apply(fix_batch_normalization_layers)
    return net.cuda()

In [None]:
class FeaturesDataset(Dataset):
    """Map-style dataset that pairs ``features[i]`` with ``labels[i]``.

    Both containers are stored as given (no copy or conversion); the dataset
    length is taken from ``labels``.
    """

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        sample = self.features[i]
        target = self.labels[i]
        return sample, target

In [None]:
from modules.loaders.balanced_sampling import create_train_dataloader


def train_NN(feats_S, lbls_S, num_epochs, params, balanced, lr, batch_size=32):
    """Train a freshly built deep classifier on source samples.

    Args:
        feats_S: source inputs, indexable per sample (wrapped in FeaturesDataset).
        lbls_S: source class labels, one per sample.
        num_epochs: number of passes over the training loader.
        params: experiment params; ``num_common + num_src_priv`` gives the
            number of output classes, and the object is passed to ``get_deep``.
        balanced: forwarded to ``create_train_dataloader`` (presumably toggles
            class-balanced sampling -- confirm in modules.loaders.balanced_sampling).
        lr: Adam learning rate.
        batch_size: training mini-batch size (32 was previously hard-coded).

    Returns:
        The trained model, switched to eval mode.
    """
    num_src_classes = params.num_common + params.num_src_priv

    model = get_deep(params, num_src_classes)
    # Move batches to wherever the model lives instead of hard-coding .cuda().
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss().to(device)
    dl = create_train_dataloader(FeaturesDataset(feats_S, lbls_S), batch_size, balanced)
    print('CLASSES:', num_src_classes)
    for _ in (pbar := tqdm(range(num_epochs))):
        total_loss = 0.0
        for f, l in dl:
            optimizer.zero_grad()
            loss = loss_fn(model(f.to(device)), l.to(device))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
        pbar.set_description('Loss: ' + str(total_loss / max(len(dl), 1)))
    return model.eval()

In [None]:
def predict_NN(model, feats, lbls, batch_size=32):
    """Return softmax class probabilities of ``model`` for every sample in ``feats``.

    Args:
        model: trained classifier; it is switched to eval mode in place.
        feats: inputs, indexable per sample (wrapped in FeaturesDataset).
        lbls: labels; only needed to build the dataset, their values are ignored.
        batch_size: inference mini-batch size (32 was previously hard-coded).

    Returns:
        CPU tensor of probabilities in input order, softmax taken over dim 1
        (presumably shape (num_samples, num_classes) -- depends on the model output).
        Empty ``feats`` makes ``torch.cat`` raise, as before.
    """
    dl = DataLoader(FeaturesDataset(feats, lbls), batch_size=batch_size, shuffle=False)
    model = model.eval()
    # Send inputs to the model's device rather than hard-coding .cuda().
    device = next(model.parameters()).device
    with torch.no_grad():
        # No .detach() needed: no autograd graph is recorded under no_grad.
        preds = [F.softmax(model(f.to(device)), dim=1).cpu() for f, _ in dl]
    return torch.cat(preds, dim=0)

In [None]:
from torch.utils.data import Subset
from modules.loaders.osda import create_datasets_sub
from modules.algorithms.base.OSLPP import Params, do_l2_normalization, do_pca
from modules.loaders.common import set_seed

# Evaluate a deep source-only classifier per source domain with stratified
# k-fold cross-validation; one row per (source, epochs, lr, fold) goes into `results`.
EPOCHS_GRID = [50]
LR_GRID = [1e-5]
N_SPLITS = 5
RESULT_COLUMNS = ['source', 'lr', 'seed', 'epochs', 'report']

# Rows are collected in a list and turned into a DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0 and row-by-row growth is quadratic.
records = []
used_src = set()
splitter = StratifiedKFold(n_splits=N_SPLITS)
for (source, target), (common, tgt_private) in config.items():
    # The classifier only uses source data, so evaluate each source once even
    # if it appears in several (source, target) pairs.
    if source in used_src:
        continue
    used_src.add(source)
    for epochs in EPOCHS_GRID:
        for lr in LR_GRID:
            set_seed(0)
            print(source, '->', target, 'lr=', lr, 'seed=', 0)
            params = Params(pca_dim=512, proj_dim=128, T=10, n_r=1200, n_r_ratio=None,
                            dataset='DomainNet_DCC', source=source, target=target,
                            num_common=len(common), num_src_priv=0, num_tgt_priv=len(tgt_private))
            (feats_S, labels_S), (_, _) = create_datasets_sub(params.dataset,
                                                              params.source,
                                                              params.target,
                                                              common,
                                                              tgt_private,
                                                              is_images=True)

            for fold, (train_idx, test_idx) in enumerate(splitter.split(feats_S, labels_S)):
                feats_S_train, lbls_S_train = feats_S[train_idx], labels_S[train_idx]
                feats_S_test, lbls_S_test = feats_S[test_idx], labels_S[test_idx]

                model = train_NN(feats_S_train, lbls_S_train, epochs, params, balanced=True, lr=lr)
                probs_S_test = predict_NN(model, feats_S_test, lbls_S_test)
                _, preds_labels = probs_S_test.max(dim=1)

                report = classification_report(lbls_S_test, preds_labels.numpy(), output_dict=True)
                # NOTE: the 'seed' column holds the fold index (kept for CSV schema compatibility).
                records.append({'source': source, 'lr': lr, 'seed': fold, 'epochs': epochs, 'report': report})
                del model  # release GPU memory before the next fold
            del feats_S, labels_S

results = pd.DataFrame(records, columns=RESULT_COLUMNS)

In [None]:
# Persist one row per evaluated fold; a header row is written by default.
# The 'report' column holds dicts, which to_csv writes as their str() form.
results.to_csv('deep_src_classifier_3.csv', index=False)