In [None]:
from pathlib import Path
import pandas as pd
import timm
import torch
from model import *
from utils import get_train_transform, get_device, set_seed
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

### Set directories

In [None]:
# Resolve every path relative to the notebook's current working directory.
root_dir = Path.cwd()
data_dir = root_dir / 'data'
img_dir = data_dir / 'test'  # images the dataset below is built from
weights_dir = root_dir / 'models' / 'best_models'  # where the checkpoint is loaded from

In [None]:
# Build the inference dataset from the sample-submission table and the test image dir.
# NOTE(review): the table presumably lists one row per test id with a placeholder
# 'target' column that the inference cell overwrites — verify against SetiDataset.
sample_submission_path = data_dir / 'sample_submission.csv'
df = pd.read_csv(sample_submission_path)
dataset = SetiDataset(df, img_dir)

In [None]:
# Torch device that both the model and every input batch are moved to (chosen by utils.get_device).
device = get_device()

### Visualize Test Data

In [None]:
# Plot every channel of one test sample side by side as a quick visual sanity check.
idx = 4  # arbitrary sample index to inspect
img, label = dataset[idx]
# squeeze=False keeps `axs` a 2-D array even for a single-channel image;
# otherwise plt.subplots returns a bare Axes and zip() over it fails.
_, axs = plt.subplots(1, img.shape[0], figsize=(15, 5), squeeze=False)
for i, (ax, ch) in enumerate(zip(axs[0], img)):
    ax.imshow(ch)
    ax.axis('off')
    ax.set_title(f'CH {i}')

### Load the trained model

In [None]:
# Build the architecture and restore the selected checkpoint for inference.
CHECKPOINT_NAME = 'best_checkpoint_fold_4_val_aucroc_0.981_010_epoch.bin'

# NOTE(review): pretrained=True downloads ImageNet weights. If DoubleNet registers
# `efficientnet` as a submodule, load_state_dict (strict by default) overwrites all of
# them and pretrained=False would skip the download — confirm in model.py.
efficientnet = timm.create_model('efficientnet_b0', pretrained=True)
net = DoubleNet(efficientnet).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded on a CPU-only machine.
checkpoint = torch.load(str(weights_dir.joinpath(CHECKPOINT_NAME)), map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()  # inference only: disable dropout and use stored BatchNorm statistics

### Predict labels for the test data

In [None]:
# Predict a hard 0/1 label for every test sample and write it back into the dataset.
# NOTE(review): if the submission is scored by ROC-AUC (the checkpoint name tracks val
# AUC), writing the raw probability instead of a thresholded label would preserve the
# ranking information the metric needs — confirm the intended output format.
net.eval()  # guard against train mode: dropout off, BatchNorm uses running statistics
with torch.no_grad():
    for idx in tqdm(range(len(dataset))):
        img, _ = dataset[idx]
        # Normalize by the per-image maximum; leave an all-zero image unscaled,
        # since dividing by 0 would turn every pixel into NaN.
        img_max = img.max()
        if img_max != 0:
            img = img / img_max
        # Add a leading batch dimension of size 1 before moving to the device.
        img = torch.unsqueeze(torch.tensor(img, dtype=torch.float), 0).to(device)
        # assumes net.predict returns a (1, 1) probability-like tensor — TODO confirm in model.py
        label = (net.predict(img).cpu().numpy() > 0.5)[0, 0].astype(int)
        dataset.__change_label__(idx, label)

In [None]:
# Display the 'target' column of the dataset's table after the inference cell has updated it.
dataset.df.loc[:, 'target']