In [2]:
import os, csv
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torchvision

from configs.model_config import model_attributes
from data import dataset_attributes, shift_types, prepare_data, log_data
from utils.train_utils import set_seed, Logger, CSVBatchLogger, log_args
from train import train
from tqdm.notebook import tqdm
import numpy as np

cpu


In [3]:
def check_args(args):
    if args.shift_type == 'confounder':
        assert args.confounder_names
        assert args.target_name
    elif args.shift_type.startswith('label_shift'):
        assert args.minority_fraction
        assert args.imbalance_ratio

In [11]:
parser = argparse.ArgumentParser()

# Settings
parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
parser.add_argument('-s', '--shift_type', choices=shift_types, default='confounder')
# Confounders
parser.add_argument('-t', '--target_name')
parser.add_argument('-c', '--confounder_names', nargs='+')
# Resume?
parser.add_argument('--resume', default=False, action='store_true')
# Label shifts
parser.add_argument('--minority_fraction', type=float)
parser.add_argument('--imbalance_ratio', type=float)
# Data
parser.add_argument('--fraction', type=float, default=1.0)
parser.add_argument('--root_dir', default=None)
parser.add_argument('--reweight_groups', action='store_true', default=False)
parser.add_argument('--augment_data', action='store_true', default=False)
parser.add_argument('--val_fraction', type=float, default=0.1)
# Objective
parser.add_argument('--robust', default=False, action='store_true')
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--generalization_adjustment', default="0.0")
parser.add_argument('--automatic_adjustment', default=False, action='store_true')
parser.add_argument('--robust_step_size', default=0.01, type=float)
parser.add_argument('--use_normalized_loss', default=False, action='store_true')
parser.add_argument('--btl', default=False, action='store_true')
parser.add_argument('--hinge', default=False, action='store_true')

# Model
parser.add_argument(
    '--model',
    choices=model_attributes.keys(),
    default='resnet50')
parser.add_argument('--train_from_scratch', action='store_true', default=False)

# Optimization
parser.add_argument('--n_epochs', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--scheduler', action='store_true', default=False)
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--minimum_variational_weight', type=float, default=0)
# Misc
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--show_progress', default=False, action='store_true')
parser.add_argument('--log_dir', default='../inv-feature/logs/')
parser.add_argument('--log_every', default=1e8, type=int)
parser.add_argument('--save_step', type=int, default=1e8)
parser.add_argument('--save_best', action='store_true', default=False)
parser.add_argument('--save_last', action='store_true', default=False)

multinli_command = ['-s', 'confounder', '-d', 'MultiNLI', '-t', 'gold_label_random',
           '-c', 'sentence2_has_negation', '--batch_size', '32', '--model', 'bert',
           '--n_epochs', '3', '--seed', '0']
celeba_command = ['-d', 'CelebA', '-t', 'Blond_Hair', '-c', 'Male', '--model', 'resnet50',
                  '--weight_decay', '0.01', '--lr', '0.0001',
                   "--batch_size", '128', '--n_epochs', '50']
waterbird_command = ['-d', 'CUB', '-t', 'waterbird_complete95', '-c', 'forest2water2', 
                     '--model', 'resnet50', '--weight_decay', '0.1', '--lr', '0.0001',
                '--batch_size', '128', '--n_epochs', '300']
command = multinli_command
command += ['--seed', '0']
args = parser.parse_args(args=command)
check_args(args)


In [8]:
torch.cuda.set_device(0)
device = torch.device('cuda')
preprocess = None

AttributeError: module 'torch._C' has no attribute '_cuda_setDevice'

In [9]:
import clip
model, preprocess = clip.load('ViT-B/32', 'cpu') 

100%|███████████████████████████████████████| 338M/338M [00:17<00:00, 20.2MiB/s]


In [None]:
import open_clip

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')

In [12]:
if args.model == 'bert':
    args.max_grad_norm = 1.0
    args.adam_epsilon = 1e-8
    args.warmup_steps = 0

if args.robust:
    algo = 'groupDRO'
elif args.reweight_groups:
    algo = 'reweight'
else:
    algo = 'ERM'

args.log_dir = os.path.join(args.log_dir, args.dataset, algo, f's{args.seed}')

if os.path.exists(args.log_dir) and args.resume:
    resume=True
    mode='a'
else:
    resume=False
    mode='w'

## Initialize logs
if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)
set_seed(args.seed)
# Data
# Test data for label_shift_step is not implemented yet
test_data = None
test_loader = None
if args.shift_type == 'confounder':
    train_data, val_data, test_data = prepare_data(args, train=True,train_transform=preprocess,eval_transform=preprocess)
elif args.shift_type == 'label_shift_step':
    train_data, val_data = prepare_data(args, train=True)

loader_kwargs = {'batch_size':args.batch_size, 'num_workers':4, 'pin_memory':True}
train_loader = train_data.get_loader(train=True, reweight_groups=args.reweight_groups, **loader_kwargs)
val_loader = val_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)
if test_data is not None:
    test_loader = test_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)

data = {}
data['train_loader'] = train_loader
data['val_loader'] = val_loader
data['test_loader'] = test_loader
data['train_data'] = train_data
data['val_data'] = val_data
data['test_data'] = test_data
n_classes = train_data.n_classes

In [13]:
n_classes

3

In [15]:
## Initialize model
pretrained = not args.train_from_scratch
if resume:
    model = torch.load(os.path.join(args.log_dir, 'last_model.pth'))
    d = train_data.input_size()[0]
elif model_attributes[args.model]['feature_type'] in ('precomputed', 'raw_flattened'):
    assert pretrained
    # Load precomputed features
    d = train_data.input_size()[0]
    model = nn.Linear(d, n_classes)
    model.has_aux_logits = False
elif args.model == 'resnet50':
    model = torchvision.models.resnet50(pretrained=pretrained)
    d = model.fc.in_features
    model.fc = nn.Linear(d, n_classes)
elif args.model == 'resnet34':
    model = torchvision.models.resnet34(pretrained=pretrained)
    d = model.fc.in_features
    model.fc = nn.Linear(d, n_classes)
elif args.model == 'wideresnet50':
    model = torchvision.models.wide_resnet50_2(pretrained=pretrained)
    d = model.fc.in_features
    model.fc = nn.Linear(d, n_classes)
elif args.model == 'bert':
    assert args.dataset == 'MultiNLI'

    from pytorch_transformers import BertConfig, BertForSequenceClassification
    config_class = BertConfig
    model_class = BertForSequenceClassification

    config = config_class.from_pretrained(
        'bert-base-uncased',
        num_labels=3,
        finetuning_task='mnli')
    model = model_class.from_pretrained(
        'bert-base-uncased',
        from_tf=False,
        config=config)
else:
    raise ValueError('Model not recognized.')

100%|██████████| 433/433 [00:00<00:00, 733791.37B/s]
100%|██████████| 440473133/440473133 [01:09<00:00, 6320360.83B/s] 


In [16]:
use_clip = type(model).__name__ == 'CLIP'
save_prefix = 'CLIP_' if use_clip else ''
load_ckpt = not use_clip

In [17]:
load_ckpt = False

In [18]:
# model = model.to(device)

if use_clip:
    encoder = model.encode_image
    output_layer = None
elif (not args.model.startswith('bert')): 
    encoder = torch.nn.Sequential(*(list(model.children())[:-1] + [torch.nn.Flatten()]))
    output_layer = model.fc



def process_batch(model, x, y = None, g = None, bert = True):
    if bert:
        input_ids = x[:, :, 0]
        input_masks = x[:, :, 1]
        segment_ids = x[:, :, 2]
        outputs = model.bert(
                input_ids=input_ids,
                attention_mask=input_masks,
                token_type_ids=segment_ids,
            )
        pooled_output = outputs[1]
        logits = model.classifier(pooled_output)
        result = {'feature':pooled_output.detach().cpu().numpy(),
                  'pred': np.argmax(logits.detach().cpu().numpy(), axis=1),
                 }
    else:
        features = encoder(x)
        result = {'feature':features.detach().cpu().numpy(),}        
        if output_layer is not None:
            logits = output_layer(features)
            result['pred'] = np.argmax(logits.detach().cpu().numpy(), axis=1),
    if y is not None: result['label'] = y.detach().cpu().numpy()
    if g is not None: result['group'] = g.detach().cpu().numpy()
    return result


In [11]:
load_ckpt

False

In [None]:
import pickle
from itertools import product
algos = ['ERM', ]
# algos = ['ERM']
model_selects = ['init']
seeds = np.arange(10)
for algo, model_select, seed in tqdm(list(product(algos,model_selects, seeds)),desc='Iter'):
    print('Current iter:',algo, model_select, seed)
    save_dir = f'/data/common/inv-feature/logs/{args.dataset}/{algo}/s{seed}/'
    if load_ckpt:
        model.load_state_dict(torch.load(save_dir + f'/{model_select}_model.pth',
                     map_location='cpu').state_dict())

    model.eval()
    for split,loader in zip(['train', 'val', 'test'], [train_loader, val_loader, test_loader]):
        results = []
        fname = f'{split}_data.p'
        fname = save_prefix + model_select + '_' + fname 
        if os.path.exists(save_dir + '/' + fname):
            continue
        with torch.set_grad_enabled(False):
            for batch_idx, batch in enumerate(tqdm(loader)):
                batch = tuple(t.to(device) for t in batch)
                x = batch[0]
                y = batch[1]
                g = batch[2]
                if args.model.startswith("bert"):
                    result = process_batch(model, x, y, g, bert=True)
                else:
                    result = process_batch(model, x, y, g, bert=False)
                results.append(result)
        parsed_data = {}
        for key in results[0].keys():
            parsed_data[key] = np.concatenate([result[key] for result in results])
        
        pickle.dump(parsed_data, open(save_dir + '/' + fname, 'wb'))

        del results
        del parsed_data

Iter:   0%|          | 0/10 [00:00<?, ?it/s]

Current iter: ERM init 0


  0%|          | 0/6443 [00:00<?, ?it/s]