In [1]:
import glob
import os
import json

from icecream import ic
import numpy as np
from mmcv import Config
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm  # Progress bar
from sklearn.metrics import confusion_matrix
import wandb

from functions import create_train_validation_and_test_scene_list, get_model, get_loss, class_decider, compute_metrics, load_model
from loaders import get_variable_options, AI4ArcticChallengeTestDataset, AI4ArcticChallengeDataset

torch.set_num_threads(10)

def load_model_cpu(net, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['model_state_dict'])
    return net

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
workdirs_dir = '/Data/sim/antonk/post74_models'
workdirs2config = {
    'workdir_01a': ['01a', 'SOD'],
    'workdir_01b': ['01b', 'SOD'],
    'workdir_01c': ['01c', 'SOD'],
    'workdir_01d': ['01d', 'SOD'],
    'workdir_01e': ['01e', 'SOD'],
    'workdir_02a': ['02a', 'SIR'],
    'workdir_02b': ['02b', 'SIR'],
    'workdir_02c': ['02c', 'SIR'],
    'workdir_02c1': ['02c', 'SIR'],
    'workdir_02c2': ['02c2', 'SIR'],
    'workdir_02c3': ['02c3', 'SIR'],
    'workdir_02c4': ['02c4', 'SIR'],
    'workdir_02d': ['02d', 'SIR'],
    'workdir_02e': ['02e', 'SIR'],
    'workdir_02f': ['02f', 'SIR'],
    'workdir_02h': ['02h', 'SIR'],
    'workdir_02i': ['02i', 'SIR'],
    'workdir_02j': ['02j', 'SIR'],
    'workdir_02k': ['02k', 'SIR'],
    'workdir_02l': ['02l', 'SIR'],
    'workdir_02m': ['02m', 'SIR'],
    'workdir_02n': ['02n', 'SIR'],
    'workdir_02o': ['02o', 'SIR'],
    'workdir_04c1': ['04c', 'SIR'],
    'workdir_04c2': ['04c2', 'SIR'],
    'workdir_04o1': ['04o', 'SIR'],
    'workdir_04o2': ['04o2', 'SIR'],
}
workdirs = sorted(list(workdirs2config.keys()))


In [10]:
api = wandb.Api(api_key='e0912e764420c974191dc9a23c71fed3d683b2b9')
entity = 'korosov-nersc'
api_projects = api.projects(f'{entity}')

In [11]:
for project in tqdm(api_projects, total=len(api_projects.objects), desc="WANDB"):
    project = project.name
    runs = api.runs(f'{entity}/{project}')
    for run in runs:
        pth_file = None
        for arg in run.metadata['args']:
            workdir = None
            if 'workdir' in arg:
                workdir = arg
                break
        if workdir is not None and workdir in workdirs:
            history = run.history()
            if len(history) == 0:
                continue
            config_suffix = workdirs2config[workdir][0]
            search_key = workdirs2config[workdir][1]

            args_config = f'configs/sic_mse/sic_mse_maud_{config_suffix}.py'
            for column in history.columns:
                if search_key in column and 'f1_metric' in column:
                    max_score_step = int(round(history['_step'][history[column].idxmax()]/10)*10)
                    pth_file = f'{workdirs_dir}/{workdir}/best_model_{workdir}_00{max_score_step:03}.pth'
                    if not os.path.exists(pth_file):
                        max_score_step -= 10
                        pth_file = f'{workdirs_dir}/{workdir}/best_model_{workdir}_00{max_score_step:03}.pth'
                    if not os.path.exists(pth_file):
                        print(f'FileNotFoundError {pth_file} not found')
        if pth_file is None:
            continue
        # load data and compute conf matrix
        ofile = 'conf_matrs_png/' + os.path.basename(pth_file).replace('.pth', '.png')
        ic(args_config)
        cfg = Config.fromfile(args_config)
        train_options = cfg.train_options
        # Get options for variables, amsrenv grid, cropping and upsampling.
        train_options = get_variable_options(train_options)
        device = 'cpu'
        net = get_model(train_options, device)
        net = load_model_cpu(net, pth_file)
        create_train_validation_and_test_scene_list(train_options)
        loss_ce_functions = {chart: get_loss(train_options['chart_loss'][chart]['type'], chart=chart, **train_options['chart_loss'][chart])
                                for chart in train_options['charts']}
        sir_name = train_options['charts'][0]
        dataset_val = AI4ArcticChallengeTestDataset(options=train_options, files=train_options['validate_list'][::5], mode='train')
        dataloader_val = torch.utils.data.DataLoader(
            dataset_val, batch_size=None, num_workers=train_options['num_workers_val'], shuffle=False)
        # - Stores the output and the reference pixels to calculate the scores after inference on all the scenes.
        outputs_flat = {chart: torch.Tensor().to(device) for chart in train_options['charts']}
        inf_ys_flat = {chart: torch.Tensor().to(device) for chart in train_options['charts']}
        for i, (inf_x, inf_y, cfv_masks, tfv_mask, name, original_size) in enumerate(tqdm(iterable=dataloader_val,
                                                                            total=len(train_options['validate_list']),
                                                                            colour='green')):
            with torch.no_grad():
                inf_x = inf_x.to(device, non_blocking=True)
                output = net(inf_x)
            for chart in train_options['charts']:
                output[chart] = class_decider(output[chart], train_options, chart)
                outputs_flat[chart] = torch.cat((outputs_flat[chart], output[chart][~cfv_masks[chart]]))
                #outputs_tfv_mask[chart] = torch.cat((outputs_tfv_mask[chart], output[chart][~tfv_mask]))
                inf_ys_flat[chart] = torch.cat((inf_ys_flat[chart], inf_y[chart][~cfv_masks[chart]].to(device, non_blocking=True)))
        cm = confusion_matrix(inf_ys_flat[sir_name], outputs_flat[sir_name])

        plt.figure(figsize=(3, 3))
        plt.imshow(cm, interpolation='nearest', cmap='Blues')
        plt.colorbar(shrink=0.5)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.tight_layout()
        plt.savefig(ofile, pad_inches=0.1)
        plt.close()


WANDB:   0%|          | 0/28 [00:00<?, ?it/s]ic| args_config: 'configs/sic_mse/sic_mse_maud_04o.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 8/8 [00:10<00:00,  1.31s/it]
 21%|[32m██        [0m| 8/38 [01:15<04:41,  9.40s/it]
ic| args_config: 'configs/sic_mse/sic_mse_maud_04o2.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 8/8 [00:14<00:00,  1.79s/it]
 21%|[32m██        [0m| 8/38 [01:18<04:55,  9.84s/it]
WANDB:  11%|█         | 3/28 [03:01<25:14, 60.58s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_04c.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 8/8 [00:09<00:00,  1.20s/it]
