# Model inference

In the previous notebook, we saw how to do model inference on the test set. Here, we show how to load an already trained/fine-tuned model and a dataset, and then run model inference on that dataset.

In [None]:
# solve issue with autocomplete (switch off jedi-based completion)
%config Completer.use_jedi = False

# reload imported modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline, below the cell that produced them
%matplotlib inline

In [None]:
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import griddata
from torchvision import transforms

from mapreader import loader
from mapreader import classifier
from mapreader import load_patches
from mapreader import patchTorchDataset

# cartopy is an optional dependency: it is only needed by the map-projection
# plot later in this notebook. Record whether it is available so that cell can
# skip itself instead of failing.
try:
    import cartopy.crs as ccrs
    ccrs_imported = True
except ImportError:
    # Plain strings: there is nothing to interpolate, so no f-prefix is needed.
    print("[WARNING] cartopy could not be imported!")
    print("[WARNING] cartopy is used for plotting the results on maps.")
    print("[WARNING] You can ignore this if you don't want to plot the results.")
    ccrs_imported = False

## Read patches (i.e., sliced images) and add metadata

First, we need to load a set of images/patches. We then use a CV (computer vision) model to do inference on these images.

In [None]:
# Load the patches matching the glob pattern, together with the path of the
# parent map sheet given via `parent_paths`.
mymaps = load_patches("./maps_tutorial/slice_50_50/*101168609*PNG", 
                      parent_paths="./maps_tutorial/map_101168609.png")

# Attach the metadata stored in the CSV file to the loaded images.
path2metadata = "./maps_tutorial/metadata.csv"
mymaps.add_metadata(metadata=path2metadata)

In [None]:
# Calculate coordinates and some pixel stats
# NOTE(review): presumably the source of the `center_lon`/`center_lat` and
# `mean_pixel_*` columns used in the plots further down — confirm.
mymaps.add_center_coord()
mymaps.calc_pixel_stats()

In [None]:
# Export the loaded images as two dataframes (by their names: one for the map
# sheets, one for the patches) and preview the first patch rows.
maps_pd, patches_pd = mymaps.convertImages(fmt="dataframe")
patches_pd.head()

In `.add_metadata`:

```python
# remove duplicates using "name" column
if columns == None:
    columns = list(metadata_df.columns)

if ("name" in columns) and ("image_id" in columns):
    print(f"Both 'name' and 'image_id' columns exist! Use 'name' to index.")
    image_id_col = "name"
if "name" in columns:
    image_id_col = "name"
elif "image_id" in columns:
    image_id_col = "image_id"
else:
    raise ValueError("'name' or 'image_id' should be one of the columns.")
```

The dataframe should have either a `name` or an `image_id` column, and that column should contain the image ID (NOT the path to the image).

In [None]:
# Move the dataframe index into a regular column and call it "image_id".
# `.add_metadata` needs such a column later on (see the snippet above).
patches_pd = patches_pd.reset_index().rename(columns={"index": "image_id"})
patches_pd.head()

In [None]:
# Keep only the image paths: this single-column dataframe is the input of the
# inference dataset, and the predictions are added to it as new columns later.
patches2infer = patches_pd[["image_path"]]
patches2infer

In [None]:
# XXX TESTING
# Optional, for quick tests: uncomment to keep only the first 1000 patches.
# patches2infer = patches2infer[:1000]

## Add patches to `patchTorchDataset`

In [None]:
# ------------------
# --- Transformation
# ------------------
# `resize2` is the size passed to transforms.Resize below.
# FOR INCEPTION
#resize2 = 299
# otherwise:
resize2 = 224

# mean and standard deviations of pixel intensities in 
# all the patches in 6", second edition maps
# Both statistics are written as `1 - <value>`.
# NOTE(review): for inverted intensities (1 - x) the mean becomes 1 - mean, but
# the standard deviation stays the same, so `1 - std` looks questionable.
# These values presumably have to match the ones used when the model was
# trained — confirm before changing them.
normalize_mean = 1 - np.array([0.82860442, 0.82515008, 0.77019864])
normalize_std = 1 - np.array([0.1025585, 0.10527616, 0.10039222])
# other options:
# normalize_mean = [0.485, 0.456, 0.406]
# normalize_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to every patch: resize, convert to a tensor,
# then normalise each channel with the statistics defined above.
# Only the 'val' entry is defined; it is the one used for inference below.
data_transforms = {
    'val': transforms.Compose(
        [transforms.Resize(resize2),
        transforms.ToTensor(),
        transforms.Normalize(normalize_mean, normalize_std)
        ]),
}


In [None]:
# Wrap the dataframe of patch paths in a dataset object that applies the
# 'val' transformation pipeline to the patches.
patches2infer_dataset = patchTorchDataset(patches2infer, 
                                          transform=data_transforms["val"])


## Load a classifier (normally trained in notebook 003)

In [None]:
# Create a classifier object and load a previously saved checkpoint into it.
# NOTE(review): the checkpoint has a .pkl extension, so it is presumably a
# pickle file — only load checkpoints that come from a trusted source.
myclassifier = classifier(device="default")
myclassifier.load("./models_tutorial/checkpoint_5.pkl")

In [None]:
# Add dataset to myclassifier
# The dataset is registered under the name "infer_test"; the same name is
# passed to the inference and plotting calls below.
# NOTE(review): shuffle=False presumably keeps the predictions in the same
# order as the rows of `patches2infer`, which the later positional column
# assignment relies on — confirm.
batch_size=64
myclassifier.add2dataloader(patches2infer_dataset, 
                            set_name="infer_test", 
                            batch_size=batch_size, 
                            shuffle=False, 
                            num_workers=0)


## Inference on `set_name`

In [None]:
# Run the model on the "infer_test" set.
# NOTE(review): presumably this fills `myclassifier.pred_label` and
# `myclassifier.pred_conf`, which are read further down — confirm.
myclassifier.inference(set_name="infer_test")


## Plot sample results

In [None]:
# Show the class names stored on the classifier (useful for choosing the
# `class_index` passed to the sample plot below).
myclassifier.class_names

In [None]:
# Show 8 sample patches from the "infer_test" set for class index 1.
# NOTE(review): min_conf/max_conf presumably bound the prediction confidence
# (min_conf=50 looks like a percentage) — confirm against the classifier docs.
myclassifier.inference_sample_results(num_samples=8, 
                                      class_index=1, 
                                      set_name="infer_test",
                                      min_conf=50,
                                      max_conf=None)

## Add model inference outputs to `mymaps`

In [None]:
# Attach the predicted label and its confidence to every patch.
# `.assign` returns a new dataframe, so nothing is written into a frame that
# was sliced from `patches_pd` (which triggers pandas' SettingWithCopyWarning).
patches2infer = patches2infer.assign(
    pred=myclassifier.pred_label,
    # confidence of a prediction = largest value along axis 1
    # (assumes pred_conf holds one row per patch and one column per class
    # — TODO confirm)
    conf=np.max(np.array(myclassifier.pred_conf), axis=1),
)
patches2infer


In [None]:
# Bring the predictions into the full patch table, matching rows on
# "image_path". how="outer" keeps rows that appear on one side only;
# validate="1:1" makes pandas raise if a path is duplicated on either side.
patches_pd = patches_pd.merge(
    patches2infer,
    on="image_path",
    how="outer",
    validate="1:1",
)

In [None]:
# Preview the merged table: it should now contain the `pred` and `conf` columns.
patches_pd.head()

In [None]:
# Add the columns of `patches_pd` (including the predictions) to the
# child-level images stored in `mymaps`.
# Rows are matched through the "image_id" column created earlier.
mymaps.add_metadata(patches_pd, 
                    tree_level="child")

## Write outputs as CSVs, one file per map sheet

In [None]:
# Export the dataframes again so that they reflect what is now stored in
# `mymaps`, then preview the patch table.
maps_pd, patches_pd = mymaps.convertImages(fmt="dataframe")
patches_pd.head()

In [None]:
# Directory that receives the CSV outputs; exist_ok=True makes this cell safe
# to re-run when the directory already exists.
output_dir = "./infer_output_tutorial"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Write two CSV files per map sheet: one with its patches, one with the sheet itself.
for one_map in maps_pd.index:
    # --- paths
    # splitext strips only the final extension, so a map ID that itself
    # contains dots keeps its full name and cannot collide with another map
    map_name = os.path.splitext(one_map)[0]
    patch2write = os.path.join(output_dir, f"patch_{map_name}.csv")
    sheet2write = os.path.join(output_dir, f"sheet_{map_name}.csv")
    # --- write outputs
    # patches whose parent is this map sheet
    patches_pd[patches_pd["parent_id"] == one_map].to_csv(patch2write, index=False)
    # the single row describing the map sheet
    maps_pd[maps_pd.index == one_map].to_csv(sheet2write, index=False)

## Load outputs and plot

Although we already have all the required dataframes/variables loaded, we re-load them here as this is a required step in most realistic applications.

In [None]:
# Start again from the files on disk: load the patches and the parent map
# sheet(s); here `parent_paths` is given as a glob pattern.
mymaps = load_patches("./maps_tutorial/slice_50_50/*101168609*PNG", 
                      parent_paths="./maps_tutorial/*101168609*png")

# add metadata (using CSV files):
path2metadata = "./maps_tutorial/metadata.csv"
mymaps.add_metadata(metadata=path2metadata)

In [None]:
# load the CSV files which contain predictions/confidence/...
# load the CSV files which contain predictions/confidence/...
# sorted(): glob returns paths in arbitrary order; sorting makes runs reproducible
path2patch = sorted(glob.glob("./infer_output_tutorial/patch*101168609*csv"))

# A dedicated loop name is used so that the global `path2metadata` (the path of
# the metadata CSV defined in the previous cell) is not overwritten.
for one_patch_csv in path2patch:
    print(one_patch_csv)
    mymaps.add_metadata(metadata=one_patch_csv, 
                        tree_level="child", 
                        delimiter=",")

# or directly:
# mymaps.add_metadata(patches_pd, tree_level="child")

Other ways to read:

- Load dataframes, add metadata:

```python
mymaps_filt = loader()

mymaps_filt.loadDataframe(parents=maps_pd, 
                          children_df=patches_filt)

# add metadata (using CSV files):
path2metadata = "./maps_tutorial/metadata.csv"
mymaps_filt.add_metadata(metadata=path2metadata)
```

- Load CSV files

```python
from mapreader import loader

mymaps = loader()
mymaps.load_csv_file(parent_path="./infer_output_tutorial/sheet_map_101168609.csv", 
                     child_path="./infer_output_tutorial/patch_map_101168609.csv")
```

In [None]:
# List of all parents
all_parents = mymaps.list_parents()

mymaps.show_par(all_parents[0], 
                value="pred",
                border=None,
                plot_parent=True,
                vmin=0, vmax=1,
                figsize=(20, 20),
                alpha=0.5, 
                colorbar="seismic")

In [None]:
# Export the re-loaded data as dataframes, print the number of patches and
# preview the first rows.
maps_pd, patches_pd = mymaps.convertImages(fmt="dataframe")
print(len(patches_pd))
patches_pd.head()

In [None]:
# Keep only patches that carry a valid prediction: not NaN and non-negative.
# Both conditions are applied to the same frame in a single mask, so neither
# filter is lost.
patches_filt = patches_pd[patches_pd["pred"].notna() & (patches_pd["pred"] >= 0)]
# number of patches per predicted label
patches_filt["pred"].value_counts()

### other plots

In [None]:
# Keep only patches whose "mean_pixel_A" value is above 0.01.
# NOTE(review): "A" is presumably the alpha channel, i.e. this drops (almost)
# fully transparent patches — confirm.
patches_filt2plot = patches_filt[(patches_filt["mean_pixel_A"] > 0.01)]


In [None]:
# Scatter plot of the patch centres: one small black dot per patch.
lon = patches_filt2plot["center_lon"].values
lat = patches_filt2plot["center_lat"].values

plt.figure(figsize=(20, 10))
plt.scatter(lon, lat, s=1, c="k")

# axis labels and tick labels
plt.xlabel("Longitude", size=30)
plt.ylabel("Latitude", size=30)
plt.xticks(size=24)
plt.yticks(size=24)
plt.show()

In [None]:
# Scatter plot of the patch centres, coloured by the mean RGB pixel value of
# each patch (colour scale limited to the range 0.6 - 0.9).
lon = patches_filt2plot["center_lon"].values
lat = patches_filt2plot["center_lat"].values
mean_rgb = patches_filt2plot["mean_pixel_RGB"].values

plt.figure(figsize=(20, 10))
plt.scatter(lon, lat, c=mean_rgb, s=30, vmin=0.6, vmax=0.9)

# axis labels, tick labels and grid
plt.xlabel("Longitude", size=30)
plt.ylabel("Latitude", size=30)
plt.xticks(size=24)
plt.yticks(size=24)
plt.grid()
plt.show()

In [None]:
# Filled contour plot of the mean RGB pixel value, interpolated on a regular
# longitude/latitude grid, drawn separately for every parent map sheet.

# inputs
# vmin/vmax: limits passed to contourf for the colour scale
vmin = 0.6
vmax = 0.92
# number of contour levels requested from contourf
levels = 15
# number of grid points along longitude (x) and latitude (y)
ngridx = 200
ngridy = 200

# one group of patches per parent map sheet
grouped = patches_filt2plot.groupby("parent_id")

plt.figure(figsize=(20, 10))
for name, group in grouped:
    # scattered sample points (patch centres) and the value at each of them
    x = group["center_lon"].values
    y = group["center_lat"].values
    z = group["mean_pixel_RGB"].values

    # Create grid values first.
    xi = np.linspace(min(x), max(x), ngridx)
    yi = np.linspace(min(y), max(y), ngridy)
    # Linear interpolation of the scattered values onto the grid; xi[None, :]
    # and yi[:, None] broadcast to an (ngridy, ngridx) grid of points.
    zi = griddata((x, y), z, 
                  (xi[None, :], yi[:, None]), 
                  method='linear')

#     plt.contour(xi, yi, zi, 
#                 levels=levels, 
#                 linewidths=0.5, colors='k', 
#                 vmin=vmin, vmax=vmax)
    
    # all groups are drawn into the same axes
    plt.contourf(xi, yi, zi, 
                 levels=levels, 
                 cmap="RdBu_r", 
                 vmin=vmin, vmax=vmax)
    
# A single colorbar is added after the loop.
# NOTE(review): with an integer `levels`, each contourf call presumably picks
# its own level values, so with several parents the colorbar may only describe
# the last group — confirm.
plt.colorbar()
plt.show()

# Alternative interpolation, kept for reference.
# NOTE(review): `tri` is not imported in this notebook (presumably matplotlib.tri).
# # Linearly interpolate the data (x, y) on a grid defined by (xi, yi).
# triang = tri.Triangulation(x, y)
# interpolator = tri.LinearTriInterpolator(triang, z)
# Xi, Yi = np.meshgrid(xi, yi)
# zi = interpolator(Xi, Yi)

In [None]:
# Same filled contour plot as before, but drawn on a cartopy map (PlateCarree
# projection) with coastlines. The cell only plots when cartopy was imported.

# inputs
# vmin/vmax: limits passed to contourf for the colour scale
vmin=0.6
vmax=0.92
# number of contour levels requested from contourf
levels=15
# number of grid points along longitude (x) and latitude (y)
ngridx = 200
ngridy = 200

if ccrs_imported:
    # one group of patches per parent map sheet
    grouped = patches_filt2plot.groupby("parent_id")

    fig = plt.figure(figsize=(20, 10))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # map extent passed to set_extent
    # (presumably [lon_min, lon_max, lat_min, lat_max] — confirm)
    #extent = [-8.08999993, 1.81388127, 49.8338702, 60.95000002]
    extent = [-0.45, 0.45, 51.3, 51.7] # extracted from metadata

    ax.set_extent(extent)
    ax.coastlines(resolution='10m', color='black', linewidth=1)

    for name, group in grouped:
        # scattered sample points (patch centres) and the value at each of them
        x = group["center_lon"].values
        y = group["center_lat"].values
        z = group["mean_pixel_RGB"].values

        # Create grid values first.
        xi = np.linspace(min(x), max(x), ngridx)
        yi = np.linspace(min(y), max(y), ngridy)
        # Linear interpolation of the scattered values onto the grid;
        # xi[None, :] and yi[:, None] broadcast to an (ngridy, ngridx) grid.
        zi = griddata((x, y), z, 
                      (xi[None, :], yi[:, None]), 
                      method='linear')

    #     plt.contour(xi, yi, zi, 
    #                 levels=levels, 
    #                 linewidths=0.5, colors='k', 
    #                 vmin=vmin, vmax=vmax,
    #                 transform=ccrs.PlateCarree())

        # plt.contourf draws into the current axes, i.e. the map axes `ax`
        # created above; transform declares the data as PlateCarree coordinates
        plt.contourf(xi, yi, zi, 
                     levels=levels, 
                     cmap="RdBu_r", 
                     vmin=vmin, vmax=vmax,
                     transform=ccrs.PlateCarree())

    # grid lines with coordinate labels
    ax.gridlines(draw_labels=True)#, xlocs=[150, 152, 154, 155])
    plt.show()
else:
    # cartopy is not available: repeat the warning printed by the import cell
    print(f"[WARNING] cartopy could not be imported!")
    print(f"[WARNING] cartopy is used for plotting the results on maps.")
    print(f"[WARNING] You can ignore this if you don't want to plot the results.")