In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from src.loaders import DataLoader
from src.models import load_model
from src.utils.loaders import load_x, load_variable

In [2]:
# NOTE(review): dataset_name is not passed to DataLoader.get_loaders below; it is
# only used to label the summary tables and to build output/input file names, so it
# must be kept in sync with whatever dataset `config` actually selects — confirm.
dataset_name = "a2AR_basal_TC685_names"
# Root directory holding the prepared data, model checkpoints and predictions.
dataset_path = Path(".data")

# Experiment configuration: the `config` variable defined inside config.py.
config_path = Path("configs") / "main" / "config.py"
config = load_variable("config", config_path)

In [3]:
# Only the test split is used; the first two returned loaders are discarded.
_, _, test_loader = DataLoader.get_loaders(path_to_data=dataset_path, config=config)
# Concatenate the first element of every batch (the model inputs; the second
# element is presumably the labels — confirm) into one tensor.
# NOTE(review): `data` is only consumed by the commented-out inference cell, so this
# full pass over the loader is otherwise unused work.
test_inputs = [inputs for inputs, _labels in test_loader]
data = torch.cat(test_inputs)

Applying train preprocessors ...:   0%|          | 0/1 [00:00<?, ?it/s]2021-03-29 03:32:36,777 INFO PrepareDataloader - Applying Permute on train set ...
Applying train preprocessors ...: 100%|██████████| 1/1 [00:00<00:00, 501.35it/s]
Applying valid preprocessors ...:   0%|          | 0/1 [00:00<?, ?it/s]2021-03-29 03:32:36,798 INFO PrepareDataloader - Applying Permute on valid set ...
Applying valid preprocessors ...: 100%|██████████| 1/1 [00:00<00:00, 501.47it/s]
Applying test preprocessors ...:   0%|          | 0/1 [00:00<?, ?it/s]2021-03-29 03:32:36,824 INFO PrepareDataloader - Applying Permute on test set ...
Applying test preprocessors ...: 100%|██████████| 1/1 [00:00<00:00, 334.15it/s]


In [4]:
# Checkpoint that is handed to tools.predict further down.
model_path = Path(".data") / "models" / "resnet18" / "723" / "checkpoints" / "epoch=17.ckpt"
# Direct in-notebook loading is currently disabled:
# model = load_model(config, model_path)

In [5]:
# NOTE(review): direct in-memory inference path; it appears to be superseded by the
# tools.predict cell that follows. Re-enabling it needs `model` (from the commented-out
# load_model call in the previous cell) and `data` (built in the loader cell).
# predictions = model(data).argmax(dim=1)

In [6]:
from tools.predict import main as predict

# Raw model outputs for this dataset. The file name must follow `dataset_name`;
# otherwise runs on different datasets overwrite each other's predictions.
prediction_path = dataset_path / "predictions" / f"{dataset_name}.npy"
prediction_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): the return value is overwritten by the reload cell below, which reads
# the saved file instead. input_path="" is presumably ignored when val_loader is given — confirm.
predictions = predict(
    config_path=config_path,
    input_path="",
    model_path=model_path,
    predict_path=prediction_path,
    val_loader=test_loader,
)

Predictions: 100%|██████████| 1218/1218 [01:40<00:00, 12.10it/s]

Prediction took: 103.9655 sec


In [7]:
# Reload the saved model outputs and reduce them to one class index per sample.
# NOTE(review): assumes the saved tensor squeezes to (n_samples, n_classes) — confirm.
raw_outputs = load_x(prediction_path).squeeze()
if raw_outputs.ndim == 1:
    # squeeze() also drops the sample axis when there is a single sample; restore it
    # so argmax(dim=1) still reduces over the class axis.
    raw_outputs = raw_outputs.unsqueeze(0)
# tolist() converts the whole index tensor to Python ints in one call.
predictions = raw_outputs.argmax(dim=1).tolist()

In [8]:
# Position i holds the name of class index i.
# NOTE(review): order assumed to match the model's label encoding — confirm.
diffusion_type = ["Anomalous", "Confined", "Directed", "Normal"]

# Count predictions per class; the list is built once rather than per class.
prediction_list = list(predictions)
stats = {name: prediction_list.count(idx) for idx, name in enumerate(diffusion_type)}
# Every prediction must fall into one of the known classes.
assert sum(stats.values()) == len(prediction_list), "predictions contain unknown class indices"

stats_2 = {"Dataset": dataset_name, **stats}

# One-row summary table; building it directly keeps integer dtypes for the counts.
stats_df = pd.DataFrame([stats_2])
# Styler.hide_index() is deprecated since pandas 1.4 and removed in 2.0.
stats_styler = stats_df.style
stats_styler.hide(axis="index") if hasattr(stats_styler, "hide") else stats_styler.hide_index()

Dataset,Anomalous,Confined,Directed,Normal
a2AR_basal_TC685_names,53,27,90,1048


In [9]:
# Sample names are read from column 0 of the dataset array.
# NOTE(review): allow_pickle=True unpickles object arrays, which can execute arbitrary
# code — only load trusted files.
names = np.load(dataset_path / f"{dataset_name}.npy", allow_pickle=True)[:, 0]

# Names and predictions are paired by position, so their lengths must agree
# (zip() would silently drop the excess). Pairing also assumes the test loader
# preserves file order, i.e. no shuffling — confirm.
assert len(names) == len(predictions), f"{len(names)} names vs {len(predictions)} predictions"

df = pd.DataFrame({"Name": names, "Class": predictions})
output_dir = dataset_path / "predictions"
output_dir.mkdir(parents=True, exist_ok=True)
df.to_csv(output_dir / f"{dataset_name}.csv", index=False)

In [12]:
df_dict = {"Dataset": ["Gi_basal_TC685_names", "a2AR_basal_TC685_names"], "Anomalous":[ 68, 53], "Confined": [47, 27], "Directed": [71, 90], "Normal": [851, 1048]}
pd.DataFrame.from_dict(df_dict, orient="index").T.style.hide_index()
# Gi_basal_TC685_names	68	47	71	851
# a2AR_basal_TC685_names	53	27	90	1048


Dataset,Anomalous,Confined,Directed,Normal
Gi_basal_TC685_names,68,47,71,851
a2AR_basal_TC685_names,53,27,90,1048
