In [1]:
import sys
sys.path.append('/vol/tensusers4/nhollain/thesis2023-2024/s_clip_scripts')

import copy
import itertools
import os
import shutil

from tqdm import tqdm

from main import main, format_checkpoint
from params import parse_args

def prep_str_args(str_args): # Code to parse the string style arguments, as shown below
    """Turn a multi-line argument string into a flat argv-style list of tokens.

    Args:
        str_args: String holding command-line style arguments, e.g.
            '--name x' on one line and '--lr 5e-5' on the next.

    Returns:
        List of non-empty tokens, e.g. ['--name', 'x', '--lr', '5e-5'].
    """
    # str.split() without a separator splits on any run of whitespace (spaces, tabs,
    # newlines) and never yields empty strings, so indentation, trailing blanks and
    # tab-separated values are all handled in one step.
    return str_args.split()
    
def evaluate(checkpoint, epoch = None, kfold = -1):
    """Run main() once per evaluation dataset for the given checkpoint.

    The dataset lists are chosen by whether the checkpoint name contains 'Fashion'.
    Zero-shot datasets are passed through '--imagenet-val' and retrieval datasets
    through '--val-data'; zero-shot runs happen first, then retrieval runs.

    Args:
        checkpoint: Checkpoint name, forwarded as the '--name' argument.
        epoch: Forwarded (as a string) as the '--resume-epoch' argument.
        kfold: Forwarded (as a string) as the '--k-fold' argument.
    """
    # The checkpoint passed by the caller is used as-is; it must not be replaced by a
    # notebook-level global, otherwise the function silently evaluates whichever
    # checkpoint happens to be in kernel state (or fails on a fresh kernel).
    if 'Fashion' in checkpoint:
        zeroshot_datasets = ["Fashion200k-SUBCLS", "Fashion200k-CLS", "FashionGen-CLS", "FashionGen-SUBCLS", "Polyvore-CLS", ]
        retrieval_datasets = ["FashionGen", "Polyvore", "Fashion200k",]
    else:
        zeroshot_datasets = ["RSICD-CLS", "UCM-CLS"] # "WHU-RS19", "RSSCN7", "AID" -> NOT WORKING bc of different data-loading workings
        retrieval_datasets = ["RSICD", "UCM", "Sydney"]

    # Pair every dataset with the flag it has to be passed under, keeping the original order
    eval_runs = [('--imagenet-val', dataset) for dataset in zeroshot_datasets]
    eval_runs += [('--val-data', dataset) for dataset in retrieval_datasets]

    for dataset_flag, dataset in eval_runs:
        # NOTE(review): with the default epoch=None this passes the literal string 'None'
        # to --resume-epoch; confirm that parse_args accepts it before relying on the default.
        str_args = ['--name', checkpoint, dataset_flag, dataset, '--resume-epoch', str(epoch), '--k-fold', str(kfold)]
        main(parse_args(str_args))
    

[nltk_data] Downloading package punkt to /home/nhollain/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/nhollain/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [2]:
from collections import Counter

# Load whatever has been logged so far; a missing eval.txt just means no evaluations yet
eval_file = './eval.txt'
eval_lines = []
if os.path.exists(eval_file):
    with open(eval_file, 'r') as f:
        eval_lines = f.readlines()

# Result lines are tab-separated (model_name\tdataset_name\tmetric_name\tmetric_value);
# anything without a tab is not a result. Only the model name (first field) is kept.
results = [line.partition('\t')[0] for line in eval_lines if '\t' in line]

# Drop the timestamp prefix (first two '-' separated parts) and everything from '-fold'
# onwards, so that only the part of the name that encodes the parameters remains.
param_keys = ('-'.join(name.split('-')[2:]).split('-fold')[0] for name in results)
# Map each parameter string to the number of result lines logged for it
model_names = dict(Counter(param_keys))
# Show which models are already present in eval.txt
# model_names

# Grid Search

In [3]:
# Grid-search setup: shared base arguments plus the values to sweep per parameter
# NOTE: for active learning, save-freq should be set to 1
base_str_args = ''' --train-data RS-ALL
--val-data RS-ALL
--imagenet-val RSICD-CLS 
--keyword-path keywords/RS/class-name.txt
--zeroshot-frequency 5  
--method base
--save-freq 5
'''
# Optional flags that can be added to the string above:
# --label-ratio 0.1
# --active-learning

# Candidate values per command-line flag; every combination becomes one config
gridsearch_dict = {
    '--epochs': [35],  # alternatives tried: list(range(15,36,5)) for active learning, [10,15,20,25,30,35]
    '--lr': [5e-5, 5e-4, 5e-6],
    '--batch-size': [128],
    '--al-iter': [3],  # alternatives tried: list(range(3,17,2)), list(range(1,6,2))
    '--al-epochs': [15],
    '--label-ratio': [0.1],  # alternatives: 0.3, 0.5, 0.7
}

# One repeat per fold: the datasets have 9 folds, so this number is chosen deliberately
num_repeats = 9
# Number of evaluations done by evaluate(...) - KEEP THIS FIXED
num_evals = 20

# Keys and values are kept in matching order so a config tuple can be mapped back to its flags
gridsearch_keys = list(gridsearch_dict)
gridsearch_values = [gridsearch_dict[key] for key in gridsearch_keys]
configs = [*itertools.product(*gridsearch_values)]
print('Number of configs:', len(configs))

Number of configs: 3


In [4]:
%%time
for c, config in enumerate(configs): # Gridsearch
    # Append this config's '--flag value' lines to the shared base arguments
    # (strings are immutable, so base_str_args needs no copy)
    config_args = ''.join('\n{} {}'.format(name, value) for name, value in zip(gridsearch_keys, config))
    str_args = prep_str_args(base_str_args + config_args)
    print(str_args)
    args = parse_args(str_args)
    checkpoint_hypothetical = format_checkpoint(args)
    # Remove the timestamp from the hypothetical checkpoint, so we can compare to the params of other checkpoints
    checkpoint_params = '-'.join(checkpoint_hypothetical.split('-')[2:]).split('-fold')[0]

    # If the exact same model has been evaluated before, skip the repeats that are already logged:
    # every finished repeat contributes num_evals lines, so the line count gives the first repeat left.
    # A model that was never evaluated starts at repeat 0.
    start_repeat = model_names.get(checkpoint_params, 0) // num_evals
    print('Config number {}: {} repeats'.format(c, max(0,num_repeats-start_repeat)))

    for fold in range(start_repeat, num_repeats):
        # Parse fresh arguments for every repeat; the repeat index doubles as the fold index
        args = parse_args(str_args)
        args.k_fold = fold
        # Evaluate at exactly the epochs for which a checkpoint is saved
        epochs = list(range(args.save_freq, args.epochs + 1, args.save_freq))
        checkpoint_path = main(args) # Calls the main.py function of S-CLIP
        for epoch in epochs:
            evaluate(checkpoint_path, epoch = epoch, kfold = fold)
        # Remove the checkpoint after evaluating, to save space. Done in Python instead of a
        # shell 'rm -r' so paths with spaces or shell metacharacters are handled safely.
        try:
            if os.path.isdir(checkpoint_path):
                shutil.rmtree(checkpoint_path)
            elif os.path.exists(checkpoint_path):
                os.remove(checkpoint_path)
        except OSError as err:
            # Cleanup is best-effort: report the failure but keep the grid search running
            print('Could not remove checkpoint {}: {}'.format(checkpoint_path, err))

['--train-data', 'RS-ALL', '--val-data', 'RS-ALL', '--imagenet-val', 'RSICD-CLS', '--keyword-path', 'keywords/RS/class-name.txt', '--zeroshot-frequency', '5', '--method', 'base', '--save-freq', '5', '--epochs', '35', '--lr', '5e-05', '--batch-size', '128', '--al-iter', '3', '--al-epochs', '15', '--label-ratio', '0.1']
Config number 0: 0 repeats
['--train-data', 'RS-ALL', '--val-data', 'RS-ALL', '--imagenet-val', 'RSICD-CLS', '--keyword-path', 'keywords/RS/class-name.txt', '--zeroshot-frequency', '5', '--method', 'base', '--save-freq', '5', '--epochs', '35', '--lr', '0.0005', '--batch-size', '128', '--al-iter', '3', '--al-epochs', '15', '--label-ratio', '0.1']
Config number 1: 9 repeats


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [06:11<00:00, 10.60s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:54<00:00, 10.12s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:39<00:00,  9.70s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:53<00:00, 10.11s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:58<00:00, 10.24s/it]
100%|███████████████████████████████████████████████████████

['--train-data', 'RS-ALL', '--val-data', 'RS-ALL', '--imagenet-val', 'RSICD-CLS', '--keyword-path', 'keywords/RS/class-name.txt', '--zeroshot-frequency', '5', '--method', 'base', '--save-freq', '5', '--epochs', '35', '--lr', '5e-06', '--batch-size', '128', '--al-iter', '3', '--al-epochs', '15', '--label-ratio', '0.1']
Config number 2: 9 repeats


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:47<00:00,  9.92s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:46<00:00,  9.91s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:48<00:00,  9.95s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:53<00:00, 10.09s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [05:44<00:00,  9.86s/it]
100%|███████████████████████████████████████████████████████

CPU times: user 1d 11h 5min 12s, sys: 36min 22s, total: 1d 11h 41min 34s
Wall time: 2h 55min 46s


# Training

In [None]:
# Switch between the fashion and the remote-sensing (RS) training configuration
fashion = False

# Arguments are written one '--flag value' pair per line; prep_str_args tokenizes them below.
# No shell-style '\' line continuations are needed (or wanted) inside these strings.
if fashion:
    str_args = '''--train-data Fashion-ALL
            --label-ratio 0.1
            --val-data Fashion-ALL
            --keyword-path keywords/fashion/class-name.txt
            --epochs 10
            --method base  
    '''
else:
    str_args = ''' --train-data RS-ALL
            --label-ratio 0.1
            --val-data RS-ALL
            --imagenet-val RSICD-CLS
            --keyword-path keywords/RS/class-name.txt
            --epochs 5
            --lr 5e-5
            --zeroshot-frequency 5  
            --method base
            --active-learning
            --al-iter 3
    '''

# Convert string arguments to a format that can be parsed by parse_args             
str_args = prep_str_args(str_args)
args = parse_args(str_args)

In [None]:
%%time
# checkpoint_path = main(args)

# Eval

In [None]:
%%time
# evaluate(checkpoint_path)