In [1]:
import glob
import os
import json

from icecream import ic
import numpy as np
from mmcv import Config
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm  # Progress bar
from sklearn.metrics import confusion_matrix
import wandb
from torchmetrics.functional import f1_score, accuracy, precision, recall

from functions import create_train_validation_and_test_scene_list, get_model, get_loss, class_decider, compute_metrics, load_model
from loaders import get_variable_options, AI4ArcticChallengeTestDataset, AI4ArcticChallengeDataset

torch.set_num_threads(10)

metric_funcs = [f1_score, accuracy, precision, recall]

def load_model_cpu(net, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    net.load_state_dict(checkpoint['model_state_dict'])
    return net

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
workdirs_dir = '/Data/sim/antonk/post74_models'

workdirs2config = {
    #'workdir_01a': ['01a', 'SOD'],
    'workdir_01b': ['01b', 'SOD'],
    'workdir_01c': ['01c', 'SOD'],
    'workdir_01d': ['01d', 'SOD'],
    'workdir_01e': ['01e', 'SOD'],
    'workdir_01f': ['01f', 'SOD'],
    'workdir_01g': ['01g', 'SOD'],
    #'workdir_02a': ['02a', 'SIR'],
    #'workdir_02b': ['02b', 'SIR'],
    #'workdir_02c': ['02c', 'SIR'],
    #'workdir_02c1': ['02c', 'SIR'],
    #'workdir_02c2': ['02c2', 'SIR'],
    #'workdir_02c3': ['02c3', 'SIR'],
    #'workdir_02c4': ['02c4', 'SIR'],
    #'workdir_02d': ['02d', 'SIR'],
    #'workdir_02e': ['02e', 'SIR'],
    #'workdir_02f': ['02f', 'SIR'],
    #'workdir_02h': ['02h', 'SIR'],
    #'workdir_02i': ['02i', 'SIR'],
    #'workdir_02j': ['02j', 'SIR'],
    #'workdir_02k': ['02k', 'SIR'],
    #'workdir_02l': ['02l', 'SIR'],
    #'workdir_02m': ['02m', 'SIR'],
    #'workdir_02n': ['02n', 'SIR'],
    #'workdir_02o': ['02o', 'SIR'],
    #'workdir_04c1': ['04c', 'SIR'],
    #'workdir_04c2': ['04c2', 'SIR'],
    #'workdir_04o1': ['04o', 'SIR'],
    #'workdir_04o2': ['04o2', 'SIR'],
}
workdirs = sorted(list(workdirs2config.keys()))


In [3]:
api = wandb.Api(api_key='e0912e764420c974191dc9a23c71fed3d683b2b9')
entity = 'korosov-nersc'
api_projects = api.projects(f'{entity}')

In [4]:
for project in tqdm(api_projects, total=len(api_projects.objects), desc="WANDB"):
    project = project.name
    runs = api.runs(f'{entity}/{project}')
    for run in runs:
        pth_file = None
        for arg in run.metadata['args']:
            workdir = None
            if 'workdir' in arg:
                workdir = arg
                break
        if workdir is not None and workdir in workdirs:
            history = run.history()
            if len(history) == 0:
                continue
            config_suffix = workdirs2config[workdir][0]
            search_key = workdirs2config[workdir][1]

            args_config = f'configs/sic_mse/sic_mse_maud_{config_suffix}.py'
            for column in history.columns:
                if search_key in column and 'f1_metric' in column:
                    max_score_step = int(round(history['_step'][history[column].idxmax()]/10)*10)
                    pth_file = f'{workdirs_dir}/{workdir}/best_model_{workdir}_00{max_score_step:03}.pth'
                    if not os.path.exists(pth_file):
                        max_score_step -= 10
                        pth_file = f'{workdirs_dir}/{workdir}/best_model_{workdir}_00{max_score_step:03}.pth'
                    if not os.path.exists(pth_file):
                        print(f'FileNotFoundError {pth_file} not found')
        if pth_file is None:
            continue
        # load data and compute conf matrix
        ofile = 'conf_matrs_png/' + os.path.basename(pth_file).replace('.pth', '.png')
        ic(args_config)
        cfg = Config.fromfile(args_config)
        train_options = cfg.train_options
        # Get options for variables, amsrenv grid, cropping and upsampling.
        train_options = get_variable_options(train_options)
        device = 'cpu'
        net = get_model(train_options, device)
        net = load_model_cpu(net, pth_file)
        create_train_validation_and_test_scene_list(train_options)
        loss_ce_functions = {chart: get_loss(train_options['chart_loss'][chart]['type'], chart=chart, **train_options['chart_loss'][chart])
                                for chart in train_options['charts']}
        sir_name = train_options['charts'][0]
        dataset_val = AI4ArcticChallengeTestDataset(options=train_options, files=train_options['validate_list'][::5], mode='train')
        dataloader_val = torch.utils.data.DataLoader(
            dataset_val, batch_size=None, num_workers=train_options['num_workers_val'], shuffle=False)
        # - Stores the output and the reference pixels to calculate the scores after inference on all the scenes.
        outputs_flat = {chart: torch.Tensor().to(device) for chart in train_options['charts']}
        inf_ys_flat = {chart: torch.Tensor().to(device) for chart in train_options['charts']}
        for i, (inf_x, inf_y, cfv_masks, tfv_mask, name, original_size) in enumerate(tqdm(iterable=dataloader_val,
                                                                            total=len(train_options['validate_list']),
                                                                            colour='green')):
            with torch.no_grad():
                inf_x = inf_x.to(device, non_blocking=True)
                output = net(inf_x)
            for chart in train_options['charts']:
                output[chart] = class_decider(output[chart], train_options, chart)
                outputs_flat[chart] = torch.cat((outputs_flat[chart], output[chart][~cfv_masks[chart]]))
                #outputs_tfv_mask[chart] = torch.cat((outputs_tfv_mask[chart], output[chart][~tfv_mask]))
                inf_ys_flat[chart] = torch.cat((inf_ys_flat[chart], inf_y[chart][~cfv_masks[chart]].to(device, non_blocking=True)))
        cm = confusion_matrix(inf_ys_flat[sir_name], outputs_flat[sir_name])
        metrics = [
            metric_func(
                target=inf_ys_flat[sir_name],
                preds=outputs_flat[sir_name],
                average='weighted',
                task='multiclass',
                num_classes=13).item()
            for metric_func in metric_funcs
            ]
        title = ' '.join([f'{m:0.2}' for m in metrics])
        plt.figure(figsize=(3, 3))
        plt.imshow(cm, interpolation='nearest', cmap='Blues')
        plt.colorbar(shrink=0.5)
        plt.ylabel('Predicted')
        plt.xlabel('True')
        plt.title(title)
        plt.tight_layout()
        plt.savefig(ofile, pad_inches=0.1)
        plt.close()


WANDB: 0it [00:00, ?it/s]ic| args_config: 'configs/sic_mse/sic_mse_maud_01g.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:12<00:00,  1.18s/it]
 22%|[32m██▏       [0m| 11/51 [00:33<02:00,  3.02s/it]
WANDB: 1it [00:52, 52.84s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01f.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:11<00:00,  1.02s/it]
 22%|[32m██▏       [0m| 11/51 [00:32<01:57,  2.94s/it]
WANDB: 4it [01:48, 18.88s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01e.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:15<00:00,  1.43s/it]
 22%|[32m██▏       [0m| 11/51 [02:37<09:30, 14.27s/it]
WANDB: 5it [05:03, 82.14s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01d.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:12<00:00,  1.14s/it]
 22%|[32m██▏       [0m| 11/51 [02:36<09:28, 14.22s/it]
WANDB: 21it [08:58,  6.77s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01c.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:12<00:00,  1.09s/it]
 22%|[32m██▏       [0m| 11/51 [00:33<02:02,  3.07s/it]
ic| args_config: 'configs/sic_mse/sic_mse_maud_01c.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 11/11 [00:10<00:00,  1.04it/s]
 22%|[32m██▏       [0m| 11/51 [00:33<02:00,  3.02s/it]
WANDB: 22it [10:43, 36.21s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01b.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


Options train_list and validate_list initialised


100%|██████████| 16/16 [00:46<00:00,  2.91s/it]
 21%|[32m██        [0m| 16/78 [00:31<02:01,  1.96s/it]
WANDB: 24it [12:11, 36.32s/it]ic| args_config: 'configs/sic_mse/sic_mse_maud_01a.py'
  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
WANDB: 24it [12:13, 30.55s/it]

FileNotFoundError /Data/sim/antonk/post74_models/workdir_01a/best_model_workdir_01a_00030.pth not found





FileNotFoundError: [Errno 2] No such file or directory: '/Data/sim/antonk/post74_models/workdir_01a/best_model_workdir_01a_00030.pth'