In [2]:
import sys
sys.path.append('/vol/tensusers4/nhollain/thesis2023-2024/s_clip_scripts')
sys.path.append('./scripts')
sys.path.append('./probvlm')

import itertools
from main import main, format_checkpoint
from params import parse_args
import copy
import os
from tqdm import tqdm
from tuning_tools import prep_str_args, evaluate_checkpoint    

In [3]:
%%time
# To evaluate CLIP without fine-tuning it! Set to False to skip
evaluate_baseline = False
if evaluate_baseline:
    evaluate_checkpoint(checkpoint_path = None, epoch = 0, split = 'test')

CPU times: total: 0 ns
Wall time: 0 ns


In [4]:
from collections import Counter

# eval.txt stores one result per line in the format
#   model_name\tdataset_name\tmetric_name\tmetric_value
# Only the model_name column is kept; lines without a tab are not results and are skipped.
# If the file does not exist yet, there are simply no previous results.
results = []
if os.path.exists('./eval.txt'):
    with open('./eval.txt', 'r') as f:
        results = [line.partition('\t')[0] for line in f if '\t' in line]

# Reduce every model name to its parameter part: drop the first two '-'-separated
# fields (the timestamp) and everything from '-fold' onwards (the specific fold).
# The resulting dict maps each parameter string to the number of result lines found for it.
model_names = dict(Counter(
    ''.join(name.split('-', 2)[2:]).partition('-fold')[0]
    for name in results
))
# Show which models are already present in eval.txt
# model_names

# Grid Search

In [5]:
# Grid-search setup: the arguments shared by every run, plus the values to sweep per parameter.
# NOTE: for active learning, save-freq should be set to 1
base_str_args = ''' --train-data ILT
--val-data ILT 
--zeroshot-frequency 5  
--save-freq 1
'''
# Optional arguments that can be moved into the string above:
# --label-ratio 0.1
# --active-learning
# --keyword-path keywords/RS/class-name.txt

# Dictionary of values to gridsearch for hyperparam tuning
# (each key is a command-line flag, each value the list of settings to try for it)
gridsearch_dict = {
    '--epochs' : [25], #list(range(15,36,5)) if 'active-learning' in base_str_args else [35], #[10,15,20,25,30,35],
    '--batch-size' : [64],
#     '--al-iter': [1], #list(range(3,17,2)), #list(range(1,6,2)),
#     '--al-epochs': [35],
    '--label-ratio': [1.0], # 0.05, 0.1, 0.2, 0.4, 0.8,
#     '--pl-method': ['hard.text'],
}


split = 'test'
# The number of validation repetitions is very specifically chosen because we have 9 folds for the datasets!
# Compare against `split` explicitly: a bare 'test' literal is always truthy, which made
# the 9-repeat branch unreachable regardless of the chosen split.
num_repeats = 5 if split == 'test' else 9
num_evals = 20 # How many evaluations are done with evaluate(...) - KEEP THIS FIXED

# Every combination of the listed values becomes one configuration; the i-th entry of a
# configuration tuple belongs to the i-th flag in gridsearch_keys.
gridsearch_values = list(gridsearch_dict.values())
gridsearch_keys = list(gridsearch_dict.keys())
configs = list(itertools.product(*gridsearch_values))
print('Number of configs:', len(configs))

Number of configs: 1


In [6]:
%%time
for c, config in enumerate(configs): # Gridsearch
    str_args = copy.deepcopy(base_str_args)
    # Add the gridsearch parameters to the arguments
    for i, param in enumerate(config):
        param_name = gridsearch_keys[i]
        str_args += '\n{} {}'.format(param_name, param)
        
    str_args = prep_str_args(str_args)
    print(str_args)
    args = parse_args(str_args)
    checkpoint_hypothetical = format_checkpoint(args)
    # Remove the timestamp from the hypothetical checkpoint, so we can compare to the params of other checkpoints
    checkpoint_params = '-'.join(checkpoint_hypothetical.split('-')[2:]).split('-fold')[0]

    # Check if we've already trained the exact same model, correct the number of training iterations we still need to do
    if checkpoint_params in model_names:
        # The number of times to repeat depends on how often the model's been evaluated already
        start_repeat = int(model_names[checkpoint_params]/num_evals)
    else: # If we've never trained + evaluated the model before, just use num_repeats
        start_repeat = 0
    print('Config number {}: {} repeats'.format(c, max(0,num_repeats-start_repeat)))
    for i in range(start_repeat, num_repeats):
        # print('repeat number', i) 
        args = parse_args(str_args)
        args.k_fold = i
        # We compute here for which epochs we need to evaluate (based on for which epochs we checkpoint)
        epoch_freq = args.save_freq
        epochs = list(range(epoch_freq,args.epochs+1,epoch_freq))
        # print('Epochs to checkpoint', epochs)
        # print('Args k fold (outside):' , args.k_fold)
        checkpoint_path = main(args) # Calls the main.py function of S-CLIP
        for epoch in epochs:
            evaluate(checkpoint_path, epoch = epoch, kfold = i, split = split)
        # Remove the checkpoint after evaluating, to save space
        os.system("rm -r {}".format(checkpoint_path))  

['--train-data', 'ILT', '--val-data', 'ILT', '--zeroshot-frequency', '5', '--save-freq', '1', '--epochs', '25', '--batch-size', '64', '--label-ratio', '1.0']
Config number 0: 5 repeats


 10%|███▋                                  | 25.2M/256M [00:06<00:56, 4.08MiB/s]


KeyboardInterrupt: 

# Training

In [None]:
# Selects which training setup is used: the Fashion-ALL arguments when True,
# the RS-ALL arguments when False.
fashion = False

if not fashion:
    # NOTE: the backslash after RSICD-CLS is a line continuation inside the string literal,
    # so that line and the following one end up joined without a newline between them.
    str_args = ''' --train-data RS-ALL
            --label-ratio 0.1
            --val-data RS-ALL
            --imagenet-val RSICD-CLS \
            --keyword-path keywords/RS/class-name.txt
            --epochs 5
            --lr 5e-5
            --zeroshot-frequency 5  
            --method base
            --active-learning
            --al-iter 3
    '''
else:
    str_args = '''--train-data Fashion-ALL
            --label-ratio 0.1
            --val-data Fashion-ALL
            --keyword-path keywords/fashion/class-name.txt
            --epochs 10
            --method base  
    '''
# Active-learning flags, kept here for reference:
           # --active-learning
        # --al-iter 3

# Turn the argument string into the form parse_args expects, then parse it
str_args = prep_str_args(str_args)
args = parse_args(str_args)

In [None]:
%%time
# checkpoint_path = main(args)

# Eval

In [None]:
%%time
# evaluate(checkpoint_path)