In [90]:
import os
import pickle
import torch
from modules.augmentation.transformations import *

In [22]:
# Experiment configuration, consumed below as a mapping
# (source, target) -> (common, tgt_private).
# NOTE: pickle can execute arbitrary code while loading; only open trusted files here.
config_path = 'experiments/configs/small_datasets.pkl'
with open(config_path, 'rb') as config_file:
    config = pickle.load(config_file)

In [23]:
# Expose only the GPU with index 1 to this process. This has an effect only if CUDA
# has not yet been initialised in the running kernel (torch initialises it lazily).
os.environ.update({"CUDA_VISIBLE_DEVICES": '1'})

In [None]:
# DomainNet domains of interest (not referenced by the cells visible here).
sources = list(('painting', 'real', 'sketch'))

In [None]:
from torch import nn
from torch.utils.data import DataLoader
from modules.loaders.resnet_features.features_dataset import FeaturesDataset
from modules.deep.DEEPNN import get_feature_extractor_model
from modules.loaders.osda import _load_tensors_sub_images
from tqdm import tqdm

def _extract_features(model, loader):
    """Run ``model`` over every batch of ``loader`` and stack the results.

    Returns a ``(features, labels)`` pair of CPU tensors concatenated along
    dim 0, in loader order (the loaders built below use ``shuffle=False``).
    """
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            feature_batches.append(model(inputs.cuda()).cpu())
            label_batches.append(labels.detach().cpu())
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


used_targets = set()
dataset = 'DomainNet_DCC'
augmentations = [
    Augmentation.TRANSLATE,
    Augmentation.ROTATE,
    Augmentation.FLIP,
    Augmentation.SCALE,
    Augmentation.AFFINE_DEGREES,
    Augmentation.BRIGHTNESS,
    Augmentation.CONTRAST,
    Augmentation.BLUR,
]
output_root = f'./features/{dataset}_aug'

# Build the extractor once; it does not depend on the target or the transform.
# eval() is required so BatchNorm/Dropout layers behave deterministically and
# running statistics are not overwritten while extracting features.
resnet_model = get_feature_extractor_model()
resnet_model = nn.Sequential(*resnet_model, nn.Flatten()).cuda().eval()

for (source, target), (common, tgt_private) in config.items():
    # Several (source, target) pairs can share a target; its test split only
    # needs to be encoded once.
    if target in used_targets:
        continue
    used_targets.add(target)

    # The un-augmented features do not depend on `transform`, so compute them
    # once per target instead of once per augmentation.
    # NOTE(review): assumes loading with transform=None is deterministic — confirm.
    tgt_features_raw, tgt_labels_raw = _load_tensors_sub_images(
        dataset, target, 'test', common, tgt_private, None)
    tgt_loader_raw = DataLoader(FeaturesDataset(tgt_features_raw, tgt_labels_raw),
                                batch_size=32, shuffle=False)
    all_raw_fs, all_raw_ls = _extract_features(resnet_model, tgt_loader_raw)

    for transform in augmentations:
        directory = f'{output_root}/{transform.name}'
        os.makedirs(directory, exist_ok=True)

        # A copy of the raw features is written next to each augmentation's
        # output, matching the per-transform directory layout.
        # NOTE(review): '.plk' looks like a typo for '.pkl'; kept because
        # downstream loaders may depend on these exact filenames.
        torch.save((all_raw_fs, all_raw_ls), f'{directory}/nocrop_{target}_test_raw.plk')

        tgt_features_aug, tgt_labels_aug = _load_tensors_sub_images(
            dataset, target, 'test', common, tgt_private, transform)
        tgt_loader_aug = DataLoader(FeaturesDataset(tgt_features_aug, tgt_labels_aug),
                                    batch_size=32, shuffle=False)
        all_aug_fs, all_aug_labels = _extract_features(resnet_model, tgt_loader_aug)

        torch.save((all_aug_fs, all_aug_labels), f'{directory}/nocrop_{target}_test_aug.plk')