In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torchmetrics
from cnn_finetune import make_model
from torchvision.utils import save_image

import time
import numpy as np
import pandas as pd
from tqdm import tqdm

from torch.utils.data import DataLoader, Dataset, Subset
from wilds import get_dataset
from wilds.datasets.wilds_dataset import WILDSSubset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper

In [4]:
# Domain names for the full DomainNet benchmark and for the four-domain
# SENTRY variant. Neither list is referenced again in this cell.
DOMAIN_NET_DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"]
SENTRY_DOMAINS = ["clipart", "painting", "real", "sketch"]

# Load DomainNet through WILDS from a copy already on disk (download=False).
# use_sentry selects the four-domain variant; False keeps the full dataset.
dataset = get_dataset(
    dataset='domainnet',
    download=False,
    root_dir='/self/scr-sync/nlp/domainnet',
    use_sentry=False,
)

# Per-example metadata; per the original author's note, column 0 holds the
# domain and column 1 holds the label y.
meta_array = dataset.metadata_array

# Groupers keyed on different metadata fields:
#   grouper       -> one group per domain (6 groups)
#   label_grouper -> one group per class label y
#                    (original note says 350 groups; DomainNet is usually
#                    described as having 345 classes -- TODO confirm)
#   full_grouper  -> one group per (domain, y) pair
grouper = CombinatorialGrouper(dataset, ['domain'])
label_grouper = CombinatorialGrouper(dataset, ['y'])
full_grouper = CombinatorialGrouper(dataset, ['domain', 'y'])

In [8]:
# Bucket every metadata row by domain; with return_counts=True the grouper
# also hands back how many examples landed in each group.
group, group_counts = grouper.metadata_to_group(
    meta_array,
    return_counts=True,
)

# Fraction of the whole dataset that each domain contributes.
print(group_counts / group_counts.sum())

tensor([0.0821, 0.0880, 0.1232, 0.2941, 0.2948, 0.1179])


In [9]:
# Bucket every metadata row by class label y and collect per-label counts.
# NOTE: this rebinds `group` and `group_counts` from the per-domain cell.
group, group_counts = label_grouper.metadata_to_group(
    meta_array,
    return_counts=True,
)

# Fraction of the whole dataset that each class label contributes.
print(group_counts / group_counts.sum())

tensor([0.0020, 0.0024, 0.0024, 0.0024, 0.0026, 0.0033, 0.0023, 0.0021, 0.0023,
        0.0027, 0.0033, 0.0025, 0.0031, 0.0030, 0.0026, 0.0030, 0.0021, 0.0024,
        0.0031, 0.0025, 0.0025, 0.0026, 0.0034, 0.0031, 0.0040, 0.0031, 0.0031,
        0.0026, 0.0029, 0.0036, 0.0027, 0.0040, 0.0024, 0.0025, 0.0031, 0.0030,
        0.0024, 0.0032, 0.0032, 0.0035, 0.0035, 0.0039, 0.0037, 0.0030, 0.0028,
        0.0021, 0.0026, 0.0032, 0.0022, 0.0036, 0.0025, 0.0029, 0.0018, 0.0015,
        0.0027, 0.0023, 0.0017, 0.0025, 0.0028, 0.0018, 0.0031, 0.0029, 0.0028,
        0.0028, 0.0034, 0.0015, 0.0022, 0.0024, 0.0021, 0.0021, 0.0024, 0.0029,
        0.0021, 0.0026, 0.0026, 0.0032, 0.0018, 0.0022, 0.0024, 0.0023, 0.0022,
        0.0026, 0.0028, 0.0025, 0.0029, 0.0019, 0.0031, 0.0033, 0.0026, 0.0022,
        0.0025, 0.0044, 0.0031, 0.0031, 0.0029, 0.0026, 0.0016, 0.0024, 0.0032,
        0.0030, 0.0033, 0.0024, 0.0024, 0.0039, 0.0031, 0.0019, 0.0038, 0.0030,
        0.0031, 0.0020, 0.0040, 0.0029, 