In [1]:
import os
import pickle 
from tqdm import tqdm

import sys
sys.path.append('/user/yz3587/llm-dro/code/embedding/data_processing/')
from embed import *
from train import *

from src.mlp_concat import *
from src.mlp_e5 import *
from src.mlp import *

In [2]:
s = 'AL,AK,AZ,AR,CA,CO,CT,DE,FL,GA,HI,ID,IL,IN,IA,KS,KY,LA,ME,MD,MA,MI,MN,MS,MO,MT,NE,NV,NH,NJ,NM,NY,NC,ND,OH,OK,OR,PA,RI,SC,SD,TN,TX,UT,VT,VA,WA,WV,WI,WY,PR'
all_states = s.split(',')
state_to_idx = {state: idx for idx, state in enumerate(all_states)}

state_dict = {
    'AL': 'Alabama','AK': 'Alaska','AZ': 'Arizona','AR': 'Arkansas','CA': 'California','CO': 'Colorado','CT': 'Connecticut',
    'DE': 'Delaware','FL': 'Florida','GA': 'Georgia','HI': 'Hawaii','ID': 'Idaho','IL': 'Illinois','IN': 'Indiana','IA': 'Iowa',
    'KS': 'Kansas','KY': 'Kentucky','LA': 'Louisiana','ME': 'Maine','MD': 'Maryland','MA': 'Massachusetts','MI': 'Michigan',
    'MN': 'Minnesota','MS': 'Mississippi','MO': 'Missouri','MT': 'Montana','NE': 'Nebraska','NV': 'Nevada','NH': 'New Hampshire',
    'NJ': 'New Jersey','NM': 'New Mexico','NY': 'New York','NC': 'North Carolina','ND': 'North Dakota','OH': 'Ohio','OK': 'Oklahoma',
    'OR': 'Oregon','PA': 'Pennsylvania','RI': 'Rhode Island','SC': 'South Carolina','SD': 'South Dakota','TN': 'Tennessee','TX': 'Texas',
    'UT': 'Utah','VT': 'Vermont','VA': 'Virginia','WA': 'Washington','WV': 'West Virginia','WI': 'Wisconsin','WY': 'Wyoming',
    'PR': 'Puerto Rico'
}

## example

In [3]:
# example
source_state = 'HI'
source_state_list = source_state.split(' ')

In [31]:
def load_data(method, state, root_dir, num_train=20000, num_val=500, num_test=5000):
    if method == 'concat':
        X, y = get_concat_data('income', state, root_dir)
    elif method == 'e5':
        X, y = get_e5_data('income', 'domainlabel', state, root_dir)
    elif method == 'one_hot':
        X, y = get_onehot_data('income', state, False, root_dir, year=2018)

    n = X.shape[0]
    if n < num_train + num_val + num_test:
        num_train = int(num_train * n / (num_train + num_val + num_test))
        num_val = int(num_val * n / (num_train + num_val + num_test))
        num_test = n - num_train - num_val

    # Setting the random seed to ensure reproducibility
    np.random.seed(42)  # You can use any number here as your seed
    # Combining the data into a single array for shuffling
    data = np.column_stack((X, y))
    np.random.shuffle(data)
    # Splitting the data back into features and labels
    X, y = data[:, :-1], data[:, -1]
    # Splitting the data into train, validation, and test sets
    trainx, trainy = X[:num_train], y[:num_train]
    valx, valy = X[num_train:num_train+num_val], y[num_train:num_train+num_val]
    testx, testy = X[-num_test:], y[-num_test:]
    return trainx, trainy, valx, valy, testx, testy

## One hot baseline

In [22]:
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
    # concat all train data into one
    if state == source_state_list[0]:
        X_train = trainx
        y_train = trainy
    else:
        X_train = np.concatenate((X_train, trainx), axis=0)
        y_train = np.concatenate((y_train, trainy), axis=0)

In [23]:
X_train.shape

(6063, 4096)

In [24]:
save_dir = '/shared/share_mala/llm-dro/'
task_name = 'income'
source_state_str = 'CA'
embedding_method = 'e5'
prompt_method = 'domainlabel'
model_dir = f'{save_dir}/save_models/{task_name}/{source_state_str}/{embedding_method}/{prompt_method}/'

In [32]:
## train model
domainlabel_model = MLPe5Classifier(input_dim=X_train.shape[1], num_classes = 2, hidden_dim=64)
y_train = torch.tensor(y_train).long()
domainlabel_model.load(5, model_dir)

  y_train = torch.tensor(y_train).long()


In [33]:
domainlabel_source_train_acc_dict, domainlabel_source_val_acc_dict, domainlabel_source_test_acc_dict = dict(), dict(), dict()
domainlabel_source_train_f1_dict, domainlabel_source_val_f1_dict, domainlabel_source_test_f1_dict = dict(), dict(), dict()

## report training, val, testing performance for each state
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
    train_acc, train_f1 = domainlabel_model.score(trainx, trainy)
    val_acc, val_f1 = domainlabel_model.score(valx, valy)
    test_acc, test_f1 = domainlabel_model.score(testx, testy)
    
    domainlabel_source_train_acc_dict[state], domainlabel_source_val_acc_dict[state], domainlabel_source_test_acc_dict[state] = train_acc, val_acc, test_acc
    domainlabel_source_train_f1_dict[state], domainlabel_source_val_f1_dict[state], domainlabel_source_test_f1_dict[state] = train_f1, val_f1, test_f1

print("Source State: ")
print(f"average train acc: {np.mean(list(domainlabel_source_train_acc_dict.values())):.3f}, average train f1: {np.mean(list(domainlabel_source_train_f1_dict.values())):.3f}")
print(f"average val acc: {np.mean(list(domainlabel_source_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(domainlabel_source_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(domainlabel_source_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(domainlabel_source_test_f1_dict.values())):.3f}")

Source State: 
average train acc: 0.743, average train f1: 0.739
average val acc: 0.781, average val f1: 0.780
average test acc: 0.740, average test f1: 0.739


In [36]:
domainlabel_target_val_acc_dict, domainlabel_target_test_acc_dict =  dict(), dict()
domainlabel_target_val_f1_dict, domainlabel_target_test_f1_dict =  dict(), dict()
for state in ['IA']:
    if state not in source_state_list:
        root_dir = '/shared/share_mala/llm-dro/income/'
        trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
        val_acc, val_f1 = domainlabel_model.score(valx, valy)
        test_acc, test_f1 = domainlabel_model.score(testx, testy)
        
        domainlabel_target_val_acc_dict[state], domainlabel_target_test_acc_dict[state] = val_acc, test_acc
        domainlabel_target_val_f1_dict[state], domainlabel_target_test_f1_dict[state] = val_f1, test_f1

print("Target State: ")
print(f"average val acc: {np.mean(list(domainlabel_target_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(domainlabel_target_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(domainlabel_target_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(domainlabel_target_test_f1_dict.values())):.3f}")

Target State: 
average val acc: 0.735, average val f1: 0.729
average test acc: 0.694, average test f1: 0.687


In [39]:
## refit one hot model on one target state

target_state = 'IA'
root_dir = '/shared/share_mala/llm-dro/income/'
trainx, trainy, valx, valy, testx, testy = load_data('e5', target_state, root_dir, num_train=20000, num_val=32, num_test=5000)

refitx = valx
refity = valy

In [40]:
## train model
refity = torch.tensor(refity).long()

domainlabel_model.refit_epochs = 50
domainlabel_model.refit_lr = 0.001
domainlabel_model.refit(refitx, refity)

  y = torch.tensor(y)
100%|██████████| 51/51 [00:22<00:00,  2.31it/s]


In [41]:
domainlabel_source_train_acc_dict, domainlabel_source_val_acc_dict, domainlabel_source_test_acc_dict = dict(), dict(), dict()
domainlabel_source_train_f1_dict, domainlabel_source_val_f1_dict, domainlabel_source_test_f1_dict = dict(), dict(), dict()

## report training, val, testing performance for each state
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
    train_acc, train_f1 = domainlabel_model.score(trainx, trainy)
    val_acc, val_f1 = domainlabel_model.score(valx, valy)
    test_acc, test_f1 = domainlabel_model.score(testx, testy)
    
    domainlabel_source_train_acc_dict[state], domainlabel_source_val_acc_dict[state], domainlabel_source_test_acc_dict[state] = train_acc, val_acc, test_acc
    domainlabel_source_train_f1_dict[state], domainlabel_source_val_f1_dict[state], domainlabel_source_test_f1_dict[state] = train_f1, val_f1, test_f1

print("Source State: ")
print(f"average train acc: {np.mean(list(domainlabel_source_train_acc_dict.values())):.3f}, average train f1: {np.mean(list(domainlabel_source_train_f1_dict.values())):.3f}")
print(f"average val acc: {np.mean(list(domainlabel_source_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(domainlabel_source_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(domainlabel_source_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(domainlabel_source_test_f1_dict.values())):.3f}")

Source State: 
average train acc: 0.715, average train f1: 0.653
average val acc: 0.698, average val f1: 0.645
average test acc: 0.692, average test f1: 0.640


In [42]:
domainlabel_target_val_acc_dict, domainlabel_target_test_acc_dict =  dict(), dict()
domainlabel_target_val_f1_dict, domainlabel_target_test_f1_dict =  dict(), dict()
for state in ['IA']:
    if state not in source_state_list:
        root_dir = '/shared/share_mala/llm-dro/income/'
        trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
        val_acc, val_f1 = domainlabel_model.score(valx, valy)
        test_acc, test_f1 = domainlabel_model.score(testx, testy)
        
        domainlabel_target_val_acc_dict[state], domainlabel_target_test_acc_dict[state] = val_acc, test_acc
        domainlabel_target_val_f1_dict[state], domainlabel_target_test_f1_dict[state] = val_f1, test_f1

print("Target State: ")
print(f"average val acc: {np.mean(list(domainlabel_target_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(domainlabel_target_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(domainlabel_target_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(domainlabel_target_test_f1_dict.values())):.3f}")

Target State: 
average val acc: 0.732, average val f1: 0.668
average test acc: 0.713, average test f1: 0.622


## llm baseline

In [5]:
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
    # concat all train data into one
    if state == source_state_list[0]:
        X_train = trainx
        y_train = trainy
    else:
        X_train = np.concatenate((X_train, trainx), axis=0)
        y_train = np.concatenate((y_train, trainy), axis=0)

In [6]:
## train model
baseline_model = MLPe5Classifier(input_dim=4096, num_classes = 2, hidden_dim=64)
baseline_model.train_epochs = 200
baseline_model.fit(X_train, y_train)

100%|██████████| 201/201 [03:18<00:00,  1.01it/s]


In [7]:
baseline_source_train_acc_dict, baseline_source_val_acc_dict, baseline_source_test_acc_dict = dict(), dict(), dict()
baseline_source_train_f1_dict, baseline_source_val_f1_dict, baseline_source_test_f1_dict = dict(), dict(), dict()

## report training, val, testing performance for each state
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
    train_acc, train_f1 = baseline_model.score(trainx, trainy)
    val_acc, val_f1 = baseline_model.score(valx, valy)
    test_acc, test_f1 = baseline_model.score(testx, testy)
    
    baseline_source_train_acc_dict[state], baseline_source_val_acc_dict[state], baseline_source_test_acc_dict[state] = train_acc, val_acc, test_acc
    baseline_source_train_f1_dict[state], baseline_source_val_f1_dict[state], baseline_source_test_f1_dict[state] = train_f1, val_f1, test_f1

print("Source State: ")
print(f"average train acc: {np.mean(list(baseline_source_train_acc_dict.values())):.3f}, average train f1: {np.mean(list(baseline_source_train_f1_dict.values())):.3f}")
print(f"average val acc: {np.mean(list(baseline_source_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(baseline_source_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(baseline_source_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(baseline_source_test_f1_dict.values())):.3f}")


Source State: 
average train acc: 0.843, average train f1: 0.830
average val acc: 0.814, average val f1: 0.804
average test acc: 0.820, average test f1: 0.803


In [8]:
baseline_target_val_acc_dict, baseline_target_test_acc_dict =  dict(), dict()
baseline_target_val_f1_dict, baseline_target_test_f1_dict =  dict(), dict()
for state in all_states:
    if state not in source_state_list:
        root_dir = '/shared/share_mala/llm-dro/income/'
        trainx, trainy, valx, valy, testx, testy = load_data('e5', state, root_dir)
        val_acc, val_f1 = baseline_model.score(valx, valy)
        test_acc, test_f1 = baseline_model.score(testx, testy)
        
        baseline_target_val_acc_dict[state], baseline_target_test_acc_dict[state] = val_acc, test_acc
        baseline_target_val_f1_dict[state], baseline_target_test_f1_dict[state] = val_f1, test_f1

print("Target State: ")
print(f"average val acc: {np.mean(list(baseline_target_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(baseline_target_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(baseline_target_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(baseline_target_test_f1_dict.values())):.3f}")

Target State: 
average val acc: 0.738, average val f1: 0.729
average test acc: 0.728, average test f1: 0.718


## concat embedding

In [9]:
'''
# train the model
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('concat', state, root_dir)
    # concat all train data into one
    if state == source_state_list[0]:
        X_train = trainx
        y_train = trainy
    else:
        X_train = np.concatenate((X_train, trainx), axis=0)
        y_train = np.concatenate((y_train, trainy), axis=0)

## train model
model = MLPconcatClassifier(input_dim=4096, num_classes = 2, hidden_dim=64, refit_method='pca', initial_embedding_method='wiki')
model.train_epochs = 100
model.fit(X_train, y_train)
'''

"\n# train the model\nfor state in source_state_list:\n    root_dir = '/shared/share_mala/llm-dro/income/'\n    trainx, trainy, valx, valy, testx, testy = load_data('concat', state, root_dir)\n    # concat all train data into one\n    if state == source_state_list[0]:\n        X_train = trainx\n        y_train = trainy\n    else:\n        X_train = np.concatenate((X_train, trainx), axis=0)\n        y_train = np.concatenate((y_train, trainy), axis=0)\n\n## train model\nmodel = MLPconcatClassifier(input_dim=4096, num_classes = 2, hidden_dim=64, refit_method='pca', initial_embedding_method='wiki')\nmodel.train_epochs = 100\nmodel.fit(X_train, y_train)\n"

In [38]:
concat_model = MLPconcatClassifier(input_dim=4096, num_classes = 2, hidden_dim=64, refit_method='pca', initial_embedding_method='wiki')

task_name = 'income'
source_state_str = '-'.join(source_state_list)
embedding_method = 'concat'
initial_embedding_method = 'wiki'
refit_method = 'pca'
save_dir = '/shared/share_mala/llm-dro/'
model_dir = f'{save_dir}/save_models/{task_name}/{source_state_str}/{embedding_method}/{initial_embedding_method}/{refit_method}/'    

concat_model.load(31, model_dir)

In [39]:
source_train_acc_dict, source_val_acc_dict, source_test_acc_dict = dict(), dict(), dict()
source_train_f1_dict, source_val_f1_dict, source_test_f1_dict = dict(), dict(), dict()

## report training, val, testing performance for each state
for state in source_state_list:
    root_dir = '/shared/share_mala/llm-dro/income/'
    trainx, trainy, valx, valy, testx, testy = load_data('concat', state, root_dir)
    train_acc, train_f1 = concat_model.score(trainx, trainy)
    val_acc, val_f1 = concat_model.score(valx, valy)
    test_acc, test_f1 = concat_model.score(testx, testy)
    
    source_train_acc_dict[state], source_val_acc_dict[state], source_test_acc_dict[state] = train_acc, val_acc, test_acc
    source_train_f1_dict[state], source_val_f1_dict[state], source_test_f1_dict[state] = train_f1, val_f1, test_f1

print("Source State: ")
print(f"average train acc: {np.mean(list(source_train_acc_dict.values())):.3f}, average train f1: {np.mean(list(source_train_f1_dict.values())):.3f}")
print(f"average val acc: {np.mean(list(source_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(source_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(source_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(source_test_f1_dict.values())):.3f}")


Source State: 
average train acc: 0.794, average train f1: 0.782
average val acc: 0.808, average val f1: 0.800
average test acc: 0.792, average test f1: 0.786


In [40]:
target_train_acc_dict, target_val_acc_dict, target_test_acc_dict = dict(), dict(), dict()
target_train_f1_dict, target_val_f1_dict, target_test_f1_dict = dict(), dict(), dict()

## report training, val, testing performance for each state
for state in ['SD']:
    if state not in source_state_list:
        root_dir = '/shared/share_mala/llm-dro/income/'
        trainx, trainy, valx, valy, testx, testy = load_data('concat', state, root_dir)
        #train_acc, train_f1 = baseline_model.score(trainx, trainy)
        val_acc, val_f1 = concat_model.score(valx, valy)
        test_acc, test_f1 = concat_model.score(testx, testy)
        
        target_val_acc_dict[state], target_test_acc_dict[state] = val_acc, test_acc
        target_val_f1_dict[state], target_test_f1_dict[state] = val_f1, test_f1

print("target State: ")
print(f"average val acc: {np.mean(list(target_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(target_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(target_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(target_test_f1_dict.values())):.3f}")


target State: 
average val acc: 0.729, average val f1: 0.692
average test acc: 0.714, average test f1: 0.683


#### target states

In [41]:
target_state = 'SD'
root_dir = '/shared/share_mala/llm-dro/income/'
trainx, trainy, valx, valy, testx, testy = load_data('concat', target_state, root_dir, num_val=100)
# concat all train data into one
refitX = valx
refity = valy

In [44]:
concat_model.load(31, model_dir)
concat_model.model.embedding.coefficients
concat_model.refit_epochs = 500
concat_model.refit_lr = 0.1
concat_model.refit(refitX, refity)

100%|██████████| 500/500 [03:05<00:00,  2.70it/s]


In [45]:
refit_target_val_acc_dict, refit_target_test_acc_dict =  dict(), dict()
refit_target_val_f1_dict, refit_target_test_f1_dict =  dict(), dict()

# record the performance of the target states
for state in ['SD']:
    if state not in source_state_list:
        root_dir = '/shared/share_mala/llm-dro/income/'
        trainx, trainy, valx, valy, testx, testy = load_data('concat', state, root_dir)
        
        val_acc, val_f1 = concat_model.score(valx, valy)
        test_acc, test_f1 = concat_model.score(testx, testy)
        refit_target_val_acc_dict[state], refit_target_test_acc_dict[state] =  val_acc, test_acc
        refit_target_val_f1_dict[state], refit_target_test_f1_dict[state] = val_f1, test_f1

print("refit target State: ")
print(f"average val acc: {np.mean(list(refit_target_val_acc_dict.values())):.3f}, average val f1: {np.mean(list(refit_target_val_f1_dict.values())):.3f}")
print(f"average test acc: {np.mean(list(refit_target_test_acc_dict.values())):.3f}, average test f1: {np.mean(list(refit_target_test_f1_dict.values())):.3f}")


refit target State: 
average val acc: 0.733, average val f1: 0.693
average test acc: 0.718, average test f1: 0.685


### test refit func

In [16]:
from refit import * 

def refit(refitx, refity, test_dict, args):
    '''
    refit and test models on target states
    '''
    # load args
    task_name, source_state, num_list, year, embedding_method, prompt_method, initial_embedding_method, refit_method, model_name, seed, experiment_id, refit_id, target_state_list, is_regression, gpu_id = args
    source_state_str = "-".join(source_state)
    # set up gpu
    if 'mlp' in model_name:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        device = torch.device(f'cuda:{gpu_id}')
        torch.cuda.set_device(device)
    # set up save dir and path
    save_dir = '/shared/share_mala/llm-dro/'
    if embedding_method == 'concat':
        model_dir = f'{save_dir}/save_models/{task_name}/{source_state_str}/{embedding_method}/{initial_embedding_method}/{refit_method}/'    
    else:
        raise NotImplementedError

    if embedding_method == 'concat':
        os.makedirs(f'{save_dir}/refit_results/{task_name}/{embedding_method}/{initial_embedding_method}/{refit_method}/{source_state_str}/{model_name}', exist_ok=True)            
        path = f'{save_dir}/refit_results/{task_name}/{embedding_method}/{initial_embedding_method}/{refit_method}/{source_state_str}/{model_name}/{experiment_id}_{refit_id}.json'
    else:
        raise NotImplementedError
    print(f"Refit {task_name}-{source_state_str}-{model_name}-ID {experiment_id}-Refit ID {refit_id} begins")

    # check if the experiment has been done
    #if os.path.exists(path):
    #    return 
    
    # save hyperparamters
    result_record = {}    
    result_record["model"] = model_name
    result_record["source_state"] = source_state_str
    result_record["year"] = year
    result_record["embedding"] = embedding_method
    if embedding_method != 'one_hot':
        result_record["prompt"] = prompt_method
    
    # load trained model and hyperparameters
    model = fetch_model('mlp_concat', is_regression, refitx.shape[1]-1, initial_embedding_method=initial_embedding_method, refit_method=refit_method)
    config = sample_config(f'mlp_concat_{refit_method}', seed, experiment_id)
    if 'mlp' in model_name:
        config["device"] = gpu_id
    result_record["config"] = config 
    result_record['initial_embedding_method'] = initial_embedding_method
    result_record['refit_method'] = refit_method   
    try: 
        model.load(experiment_id, model_dir)
    except:
        raise ValueError(f"Model {model_name}_{experiment_id} not found in {model_dir}")
    # load refit hyperparameters
    refit_config = sample_config('refit_mlp_concat', seed, refit_id)
    result_record['refit_config'] = refit_config
    model.update_refit_config(refit_config)
    print(refit_config)

    # refit model
    model.fit_embeddings(refitx, refity)

    # model testing
    test_result_acc = {}
    test_result_f1 = {}
    for target_state in target_state_list:
        if target_state in source_state: # do not load if source state
            continue
        else:
            testx, testy = test_dict[target_state]
            # save accuracy and f1 score
            acc, f1 = model.score(testx, testy)
            test_result_acc[target_state] = acc 
            test_result_f1[target_state] = f1
    # save test results
    result_record["test_result_acc"] = test_result_acc
    result_record["test_result_f1"] = test_result_f1
    if 'mlp' in model_name:
        result_record["config"]["device"] = gpu_id
    
    # save result
    #with open(path, 'w') as f:
    #    json.dump(result_record, f)
    del model 
    torch.cuda.empty_cache()
    gc.collect() 
    print(f"Experiment {task_name}-{source_state_str}-{model_name}-ID {experiment_id}-Refit ID {refit_id} finished!!")
    return result_record

In [17]:
# setup args
task = 'income'
source = ['CA', 'TX', 'FL', 'NY', 'PA']
num = [5000, 5000, 5000, 5000, 5000]
year = 2018
embedding = 'concat'
prompt = None
initial_embedding_method = 'wiki'
refit_method = 'pca'
model = 'mlp'
experiment_id = 31
refit_id = 14
num_gpus = torch.cuda.device_count()

arg = task, source, num, 2018, embedding, prompt, initial_embedding_method, refit_method, model, 0, experiment_id, refit_id, ALL_STATES, 0, experiment_id%num_gpus

In [18]:
# load validation and test data
valx, valy, test_dict = load_val_test_data(arg)
print(valx.shape, valy.shape)

(23552, 4097) (23552,)


In [19]:
### refit model
result_record = refit(valx, valy, test_dict, arg)

Refit income-CA-TX-FL-NY-PA-mlp-ID 31-Refit ID 14 begins
{'refit_lr': 0.001, 'refit_epochs': 200, 'refit_num': 512}


100%|██████████| 200/200 [04:08<00:00,  1.24s/it]


Experiment income-CA-TX-FL-NY-PA-mlp-ID 31-Refit ID 14 finished!!


In [20]:
print("refit target State: ")
print(f"average test acc: {np.mean(list(result_record['test_result_acc'].values())):.3f}, average test f1: {np.mean(list(result_record['test_result_f1'].values())):.3f}")


refit target State: 
average test acc: 0.805, average test f1: 0.774


In [1]:
from train import *
from refit import *

In [21]:
embedding_method = 'one_hot'
prompt_method = None
target_state_list = ['WY', 'PR']
year = 2018
seed = 0

source_state_list = ['CA']
task_name = 'mobility'
save_dir = '/shared/share_mala/llm-dro/'
refit_dict = {}   # validation data
test_dict = {}
refit_num = 1024
# load validation and test data
for idx, state in enumerate(target_state_list):
    if state in source_state_list: # do not load if source state
        continue
    else: # load validation/test data if target state
        X, y = get_raw_data(task_name, embedding_method, prompt_method, 
                            state, save_dir, year)
        # check if refit num is larger than the data size
        if refit_num > X.shape[0]:
            cur_refit_num = X.shape[0] // 2
        else:
            cur_refit_num = refit_num
        # sample training/validation data
        valx, valy, testx, testy = sample_val_test_data(X, y, val_num = cur_refit_num, seed=seed)  
        refit_dict[state] = [valx, valy]
        test_dict[state] = [testx, testy]


In [23]:
test_dict['WY'][0].shape

(509, 63)

In [20]:
X.shape[0]

4730