In [1]:
import sys
sys.path.append('/vol/tensusers5/nhollain/s_clip_scripts')

import itertools
from main import main, format_checkpoint
from params import parse_args
import copy

def prep_str_args(str_args):
    """Convert a multi-line, command-line style argument string into an argv-like list.

    Each line is expected to look like ``--name value`` (or just ``--flag``).
    Lines are separated on '\\n', surrounding whitespace of each line is stripped,
    the line is split on single spaces, and empty pieces (from blank lines or
    repeated spaces) are discarded.

    Example: ' --epochs 5\\n  --lr 5e-5 ' -> ['--epochs', '5', '--lr', '5e-5']

    Args:
        str_args: the raw argument string.

    Returns:
        A flat list of non-empty string tokens, suitable for ``parse_args``.
    """
    tokens = []
    for line in str_args.split('\n'):
        # Only ' ' separates tokens; tabs inside a line are kept as part of a token.
        for piece in line.strip().split(' '):
            if piece:
                tokens.append(piece)
    return tokens
    
def evaluate(checkpoint):
    """Run zero-shot classification and retrieval evaluation for one checkpoint.

    The checkpoint name selects the benchmark family: names containing 'Fashion'
    are evaluated on the fashion datasets, all others on the remote-sensing (RS)
    datasets. For every dataset, S-CLIP's ``main`` is called once with
    ``--name <checkpoint>`` plus either ``--imagenet-val <dataset>`` (zero-shot
    classification) or ``--val-data <dataset>`` (retrieval).

    Args:
        checkpoint: checkpoint name as returned by ``main(args)`` after training.
    """
    # Use the argument itself; do not read any notebook-global checkpoint variable.
    if 'Fashion' in checkpoint:
        zeroshot_datasets = ["Fashion200k-SUBCLS", "Fashion200k-CLS", "FashionGen-CLS", "FashionGen-SUBCLS", "Polyvore-CLS"]
        retrieval_datasets = ["FashionGen", "Polyvore", "Fashion200k"]
    else:
        # "WHU-RS19", "RSSCN7", "AID" are not evaluated: they do not work with the
        # current data-loading code.
        zeroshot_datasets = ["RSICD-CLS", "UCM-CLS"]
        retrieval_datasets = ["RSICD", "UCM", "Sydney"]

    # Zero-shot datasets go through --imagenet-val, retrieval datasets through --val-data;
    # zero-shot runs first, matching the original evaluation order.
    eval_jobs = ([('--imagenet-val', dataset) for dataset in zeroshot_datasets]
                 + [('--val-data', dataset) for dataset in retrieval_datasets])
    for flag, dataset in eval_jobs:
        args = parse_args(['--name', checkpoint, flag, dataset])
        main(args)

In [2]:
from collections import Counter
# eval.txt holds result lines of the form
#   model_name\tdataset_name\tmetric_name\tmetric_value
# Lines without a tab are not results and are skipped; only model_name is kept.
with open('./eval.txt', 'r') as eval_file:
    results = [line.replace('\n', '').split('\t')[0]
               for line in eval_file if '\t' in line]

# Drop the leading timestamp (the first two '-'-separated fields, e.g. '2023_11_07-16_49_19')
# and everything from '-fold' onward, leaving only the hyper-parameter part of the name.
run_params = ('-'.join(name.split('-')[2:]).split('-fold')[0] for name in results)
# Map each hyper-parameter string to the number of result lines logged for it.
model_names = dict(Counter(run_params))
model_names

{'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.0005-bs_64': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.0005-bs_128': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.00005-bs_64': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.00005-bs_128': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.000005-bs_64': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.000005-bs_128': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_15-lr_0.0005-bs_64': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_15-lr_0.0005-bs_128': 180,
 'data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_

# Grid Search

In [None]:
# %%time
# Grid search over the hyper-parameters in `gridsearch_dict`. Each config is trained
# `num_repeats` times (one per k-fold index), and each trained checkpoint is evaluated.
# Requires `model_names` from the eval.txt cell above, to skip repeats already logged.
base_str_args = ''' --train-data RS-ALL
--label-ratio 0.1
--val-data RS-ALL
--imagenet-val RSICD-CLS
--keyword-path keywords/RS/class-name.txt
--zeroshot-frequency 5
--method base
'''

# Dictionary of values to grid-search for hyper-parameter tuning
# (previously also tried: --lr 5e-4, 5e-6; --batch-size 128, 256)
gridsearch_dict = {
    '--epochs': [5, 10, 15, 20, 25, 30, 35],
    '--lr': [5e-5],
    '--batch-size': [64],
}
num_repeats = 9  # Number of k-fold repeats per config
num_evals = 20   # Result lines logged to eval.txt per repeat; used to infer completed repeats

gridsearch_values = list(gridsearch_dict.values())
gridsearch_keys = list(gridsearch_dict.keys())
configs = list(itertools.product(*gridsearch_values))
print('Number of configs:', len(configs))

for c, config in enumerate(configs):  # Gridsearch
    # Append one "--param value" line per grid-searched parameter (str is immutable: no copy needed)
    str_args = base_str_args + ''.join(
        '\n{} {}'.format(param_name, param) for param_name, param in zip(gridsearch_keys, config)
    )
    str_args = prep_str_args(str_args)
    print(str_args)
    args = parse_args(str_args)
    checkpoint_hypothetical = format_checkpoint(args)
    # Remove the timestamp and fold suffix from the hypothetical checkpoint, exactly as done
    # for `model_names`, so the two can be compared
    checkpoint_params = '-'.join(checkpoint_hypothetical.split('-')[2:]).split('-fold')[0]
    print(checkpoint_params)
    # Repeats already completed = result lines logged for these params // lines per repeat
    start_repeat = model_names.get(checkpoint_params, 0) // num_evals
    # Clamp so over-complete configs report 0 instead of a negative number of repeats
    print('Config number {}: {} repeats'.format(c, max(num_repeats - start_repeat, 0)))
    for i in range(start_repeat, num_repeats):
        print('repeat number', i)
        # Re-parse so every repeat starts from a fresh args object
        args = parse_args(str_args)
        args.k_fold = i
        print('Args k fold (outside):', args.k_fold)
        checkpoint_path = main(args)  # Calls the main.py function of S-CLIP
        evaluate(checkpoint_path)

Number of configs: 7
['--train-data', 'RS-ALL', '--label-ratio', '0.1', '--val-data', 'RS-ALL', '--imagenet-val', 'RSICD-CLS', '--keyword-path', 'keywords/RS/class-name.txt', '--zeroshot-frequency', '5', '--method', 'base', '--epochs', '5', '--lr', '5e-05', '--batch-size', '64']
data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_5-lr_0.00005-bs_64
Config number 0: 0 repeats
['--train-data', 'RS-ALL', '--label-ratio', '0.1', '--val-data', 'RS-ALL', '--imagenet-val', 'RSICD-CLS', '--keyword-path', 'keywords/RS/class-name.txt', '--zeroshot-frequency', '5', '--method', 'base', '--epochs', '10', '--lr', '5e-05', '--batch-size', '64']
data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64
Config number 1: 9 repeats
repeat number 0
Args k fold (outside): 0
formatting...
kfold: 0
Log path: ./checkpoint/2023_11_07-16_49_19-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [03:05<00:00, 18.54s/it]


=> resuming checkpoint './checkpoint/2023_11_07-16_49_19-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_0/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD-CLS (split: val)
CLS size: 1094
=> resuming checkpoint './checkpoint/2023_11_07-16_49_19-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_0/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
UCM-CLS (split: val)
CLS size: 210
=> resuming checkpoint './checkpoint/2023_11_07-16_49_19-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_0/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD (split: val)
=> resuming checkpoint './checkpoint/2023_11_07-16_49_19-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_0/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:53<00:00, 17.32s/it]


=> resuming checkpoint './checkpoint/2023_11_07-16_53_01-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_1/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD-CLS (split: val)
CLS size: 1094
=> resuming checkpoint './checkpoint/2023_11_07-16_53_01-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_1/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
UCM-CLS (split: val)
CLS size: 210
=> resuming checkpoint './checkpoint/2023_11_07-16_53_01-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_1/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD (split: val)
=> resuming checkpoint './checkpoint/2023_11_07-16_53_01-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_1/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:54<00:00, 17.48s/it]


=> resuming checkpoint './checkpoint/2023_11_07-16_56_28-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_2/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD-CLS (split: val)
CLS size: 1094
=> resuming checkpoint './checkpoint/2023_11_07-16_56_28-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_2/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
UCM-CLS (split: val)
CLS size: 210
=> resuming checkpoint './checkpoint/2023_11_07-16_56_28-data_RS-ALL-ratio_0.1-model_RN50-method_base-kw_none-AL_False-PL_None-vit_False-epochs_10-lr_0.00005-bs_64-fold_2/checkpoints/epoch_latest.pt' (epoch 10)
Getting data...
RSICD (split: val)


# Training

In [4]:
# Set to True to train on the fashion datasets instead of the remote-sensing (RS) ones
fashion = False

if fashion:
    str_args = '''--train-data Fashion-ALL
            --label-ratio 0.1
            --val-data Fashion-ALL
            --keyword-path keywords/fashion/class-name.txt
            --epochs 10
            --method base
    '''
else:
    # One "--flag value" per line. Do not end a line with '\': inside a normal string
    # literal it is a line continuation that silently merges two argument lines.
    # To enable active learning, add '--active-learning' and '--al-iter 3' below.
    str_args = ''' --train-data RS-ALL
            --label-ratio 0.1
            --val-data RS-ALL
            --imagenet-val RSICD-CLS
            --keyword-path keywords/RS/class-name.txt
            --epochs 25
            --lr 5e-5
            --zeroshot-frequency 5
            --method base
    '''

# Convert string arguments to a format that can be parsed by parse_args
# (surrounding whitespace and blank lines are discarded by prep_str_args)
str_args = prep_str_args(str_args)
args = parse_args(str_args)

In [None]:
%%time
# Train with the `args` parsed in the cell above. S-CLIP's main() returns the checkpoint
# name, which the evaluation cell below passes to evaluate() (as --name).
checkpoint_path = main(args)

# Eval

In [None]:
%%time
# Run the zero-shot classification and retrieval evaluations on the checkpoint
# produced by the training cell above.
evaluate(checkpoint_path)