In [1]:
import os
from datetime import datetime, timedelta
import time
import glob

# My modules
from config import Config
from logger import init_logger
from common_utils import set_seeds, create_holdout_loader, get_valid_transforms
from train_loop_functions import ensemble_inference
from cassava_dataset import CassavaDataset

from torch.utils.data import DataLoader

import pandas as pd

from sklearn.metrics import accuracy_score

import torch

In [2]:
# Enable IPython's autoreload extension. Mode 2 re-imports modules before each
# cell executes, so edits to the local project modules imported above are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Seed via the project's set_seeds helper using the configured seed.
# NOTE(review): which RNGs (random / numpy / torch) are seeded is defined in
# common_utils.set_seeds -- confirm there.
set_seeds(Config.seed)
# Notebook-level logger; run_inference below reports through it.
LOGGER = init_logger() # uses Python's logging framework

## Load the models

In [4]:
# Collect the saved weights of every checkpoint in the listed experiment
# directories. Each entry of `model_states` is whatever the checkpoint stores
# under 'model_state' (preferred) or 'model'.
experiment_name_dirs = ['exp1_88_Adam']
model_states = []
for experiment in experiment_name_dirs:
    base = os.path.join(Config.save_dir, experiment)
    # find all files that end in "pth"; glob returns them in arbitrary order,
    # so sort to keep the ensemble member order reproducible across runs
    model_filenames = sorted(glob.glob(os.path.join(base, '*.pth')))
    for f in model_filenames:
        print(f)
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # hosts (run_inference falls back to CPU when CUDA is unavailable).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(f, map_location='cpu')
        if 'model_state' in checkpoint:
            model_states.append(checkpoint['model_state'])
        elif 'model' in checkpoint:
            model_states.append(checkpoint['model'])
        else:
            # Same exception type as before, but say which file and which keys.
            raise KeyError(
                f"{f}: expected key 'model_state' or 'model', found {list(checkpoint)}"
            )

print(f'Loaded {len(model_states)} models')

./trained-models/exp1_88_Adam/tf_efficientnet_b4_ns_fold0.pth
./trained-models/exp1_88_Adam/tf_efficientnet_b4_ns_fold1.pth
./trained-models/exp1_88_Adam/tf_efficientnet_b4_ns_fold2.pth
./trained-models/exp1_88_Adam/tf_efficientnet_b4_ns_fold3.pth
Loaded 4 models


In [5]:
def run_inference(model_states, model_arch, kaggle,
                  holdout_csv='./trained-models/exp6_sgd/holdout.csv'):
    """Run ensemble inference and either score it or write a submission.

    Args:
        model_states: Loaded model weights, forwarded unchanged to
            ``ensemble_inference``.
        model_arch: Architecture identifier, forwarded unchanged to
            ``ensemble_inference``.
        kaggle: If truthy, predict on every file in ``Config.test_img_dir``
            and write ``submission.csv`` (columns ``image_id``, ``label``).
            Otherwise predict on the holdout set read from ``holdout_csv``
            and log the accuracy against its targets.
        holdout_csv: Path of the holdout CSV used when ``kaggle`` is falsy.
            Defaults to the previously hard-coded location.

    Returns:
        None. Results are logged through ``LOGGER`` or written to
        ``submission.csv`` in the current working directory.
    """
    LOGGER.info('========== Running inference ==========')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    targets = None  # only populated on the holdout path

    if kaggle:
        # os.listdir returns entries in arbitrary order; sort so the rows of
        # the submission file are reproducible between runs.
        df = pd.DataFrame()
        df['image_id'] = sorted(os.listdir(Config.test_img_dir))
        test_dataset = CassavaDataset(df, Config.test_img_dir,
                                      transform=get_valid_transforms(Config.img_size),
                                      output_label=False)
        loader = DataLoader(test_dataset, batch_size=Config.valid_bs)
    else:
        # read holdout set from csv...
        df = pd.read_csv(holdout_csv, engine='python')
        loader, targets = create_holdout_loader(df, Config.train_img_dir)
    num_samples = len(df)

    inference_start = time.time()

    # NOTE(review): predictions are used below as one label per row of `df`,
    # in loader order -- confirm against ensemble_inference.
    predictions = ensemble_inference(model_states, model_arch,
                                     Config.num_classes, loader, num_samples, device,
                                     mode='vote', kaggle=kaggle)

    inference_elapsed = time.time() - inference_start

    LOGGER.info(f"Inference time: {timedelta(seconds=inference_elapsed)}")

    if not kaggle:
        holdout_accuracy = accuracy_score(y_true=targets, y_pred=predictions)
        LOGGER.info(f"Ensemble model holdout accuracy: {holdout_accuracy}")
    else: # make submission file
        submission = pd.DataFrame()
        submission['image_id'] = df['image_id']
        submission['label'] = predictions
        submission.to_csv('submission.csv', index=False)
    

In [8]:
# Kaggle mode: predict on the test image directory and write submission.csv.
run_inference(
    model_states=model_states,
    model_arch=Config.model_arch,
    kaggle=True,
)

100%|██████████| 4/4 [00:02<00:00,  1.84it/s]
Inference time: 0:00:02.170352
