In [None]:
# Re-import edited project modules (e.g. the `inpainting` package) automatically
# before each cell is executed, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Make the parent directory importable (presumably the project root that
# contains the `inpainting` package imported below).
import sys
sys.path.extend([".."])

In [None]:
import torch
import numpy as np
from pathlib import Path
import pickle
import matplotlib.pyplot as plt
from inpainting.visualizations import visualizations_utils as vis
from inpainting.visualizations.digits import img_with_mask
import pandas as pd
from inpainting.evaluation import evaluation as ev
from tqdm import tqdm
from itertools import chain
from inpainting.evaluation.frechet_models import MNISTNet
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from sklearn.metrics import accuracy_score
from inpainting.evaluation import fid
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.patches as patches


In [None]:
# Experiment directories, one mapping per dataset: imputer-model name ->
# directory containing that model's `val_predictions_16x16.pkl`.
#
# MNIST runs (disabled):
# mnist_experiments_paths = dict(
#     misgan=Path("../results/mnist/misgan/"),
#     torch_mfa=Path("../../gmm_missing/models/mnist"),
#     gmm_fullconv=Path("../results/mnist/long_trainings/fullconv_v1"),
#     gmm_linear_heads=Path("../results/mnist/long_trainings/linear_v1"),
# )

celeba_32_experiments_paths = dict(
    gmm_linear_heads=Path("../results/celeba/linear_heads/32x32/scripted_v2_after_fix"),
    torch_mfa=Path("../../gmm_missing/models/celeba_32_32"),
    # Alternative CelebA runs (disabled):
    # gmm_fullconv=Path("../results/celeba/fullconv/32x32/a_ampl_0.2"),
    # gmm_fullconv_nll_1_mse_10=Path("../results/celeba/fullconv/32x32/ampl_0.2_nll_1_mse_10/"),
    # gmm_fullconv_nll_1_mse_1=Path("../results/celeba/fullconv/32x32/ampl_0.2_nll_1_mse_1/"),
    # gmm_fullconv_nll_1_mse_1_stopped_after_10_epochs=Path("../results/celeba/fullconv/32x32/ampl_0.2_nll_1_mse_1_for_10_epochs_and_then_0/"),
)

svhn_experiments_paths = dict(
    dmfa_comp=Path("../results/inpainting/svhn/fullconv/complete_data/dmfa_mse_10_eps/"),
    dmfa_incomp=Path("../results/inpainting/svhn/fullconv/incomplete_data/dmfa_mse_10_eps_v4_train_det/"),
    torch_mfa=Path("../../gmm_missing/models/svhn_32_32/"),
)

# The dataset analysed by the rest of the notebook.
experiments_paths = svhn_experiments_paths

In [None]:
# !ls ../../gmm_missing/models/celeba_32_32

In [None]:
# Load the pickled validation predictions of every configured experiment.
# NOTE(review): pickle.load can execute arbitrary code -- only load result
# files produced by this project.
experiments_results = {}
for name, path in experiments_paths.items():
    print(name)
    predictions_file = path / "val_predictions_16x16.pkl"
    with open(predictions_file, "rb") as f:
        experiments_results[name] = pickle.load(f)

In [None]:
# Show which experiments were loaded.
experiments_results.keys()

In [None]:
# x, j, p, m ,a, d, y = experiments_results["torch_mfa"][0]

In [None]:
# [t.shape for t in [x, j, p, m, a, d, y]]

In [None]:
# For every model, turn each raw validation record into an
# (images_dict, label) pair; the label is the record's last element.
experiments_images = {
    model_name: [(ev.outputs_to_images(*record), record[-1]) for record in records]
    for model_name, records in experiments_results.items()
}

# NLL and MSE

In [None]:
# For every model, pair the loss-like metrics (NLL, MSE, ...) of each
# validation record with the record itself.
ml_metrics = {}
for model_name, records in tqdm(experiments_results.items()):
    ml_metrics[model_name] = [
        (ev.loss_like_metrics(record), record) for record in records
    ]

In [None]:
# ml_metrics_df[["imputer_model", "nll", "mse"]].to_csv("celeba_nll.csv")

In [None]:
# Only the first N_METRIC_SAMPLES validation records per model enter this
# table -- and therefore the summary statistics and CSV files derived from it
# below. Set to None to use every record.
N_METRIC_SAMPLES = 20

ml_metrics_df = pd.DataFrame([
    {
        "imputer_model": model,
        # "smieja_nll": m[1][-1][1] if len(m[1][-1].shape) > 0 else None,
        **metric_values,
    }
    for model, metrics in ml_metrics.items()
    for metric_values, _record in metrics[:N_METRIC_SAMPLES]
])

ml_metrics_df

In [None]:
# Persist mean and std of every loss-like metric per model.
(
    ml_metrics_df
    .groupby("imputer_model")
    .agg(["mean", "std"])
    .to_csv("svhn_summary.csv")
)

In [None]:
# Persist the per-record NLL and MSE of every model.
ml_metrics_df.loc[:, ["imputer_model", "nll", "mse"]].to_csv("svhn_nll.csv")

In [None]:
# mfa_df = ml_metrics_df[ml_metrics_df.imputer_model == "torch_mfa"]
# mfa_df

In [None]:
# Mean and standard deviation of every loss-like metric, per model.
ml_metrics_df.groupby(by="imputer_model").agg(func=["mean", "std"])

In [None]:
# Scratch arithmetic: 226 divided by the element count of a 16x16, 3-channel
# patch (presumably normalising a total to a per-element value).
# NOTE(review): the origin of 226 is not recorded -- TODO document or remove.
226 / (16*16*3)

In [None]:
# One box plot per loss-like metric, comparing the imputer models.
for metric_name in ("nll", "mse"):
    fig, metric_ax = plt.subplots(figsize=(15, 5))
    sns.boxplot(
        data=ml_metrics_df,
        x="imputer_model",
        y=metric_name,
        ax=metric_ax,
    )
    metric_ax.set_title(metric_name)
    plt.show()

# Worst and best NLL cases for every model

In [None]:
# Plot the n worst and n best validation records (by NLL) for each model.
n = 1  # number of worst / best cases shown per model

# Restrict to a subset of models, e.g. {"dmfa_incomp"}; None shows all.
# (The previous hard-coded filter named a model that is not part of
# `experiments_paths`, so this cell silently plotted nothing.)
models_to_show = None

for model, metrics_with_cases in ml_metrics.items():
    if models_to_show is not None and model not in models_to_show:
        continue
    if not metrics_with_cases:
        continue
    metrics_with_cases = sorted(
        metrics_with_cases, key=lambda m_c: m_c[0]["nll"]
    )
    row_len = vis.row_length(*metrics_with_cases[0][1])
    for (name, mwc) in [
        ("worst", metrics_with_cases[-n:]),
        ("best", metrics_with_cases[:n]),
    ]:
        print(model, name, n)
        # squeeze=False keeps `ax` 2-D even when n == 1; otherwise ax[i]
        # would be a single Axes and ax[i, 1] would raise IndexError.
        fig, ax = plt.subplots(
            nrows=n, ncols=row_len, figsize=(2 * row_len, n * 2), squeeze=False
        )
        for i, (m, c) in enumerate(mwc):
            vis.visualize_sample(*c, ax_row=ax[i])
            ax[i, 1].set_title(f"nll = {m['nll']:.2f}")
        plt.show()



# Skimage image metrics (SSIM and PSNR)

In [None]:
def plot_exp_images(images_dicts, figsize=(15,15)):
    """Plot a grid of images: one row per sample, one column per image kind.

    Args:
        images_dicts: sequence of ``(images, label)`` pairs; ``images`` maps an
            image-kind name to an array, which is squeezed and shown in
            grayscale on a fixed [0, 1] range. The label is not used.
        figsize: matplotlib figure size.
    """
    width = len(images_dicts[0][0])
    height = len(images_dicts)
    # squeeze=False: with a single row or column plt.subplots would otherwise
    # return a 1-D array (or a bare Axes) and ax[i][j] would fail.
    fig, ax = plt.subplots(height, width, figsize=figsize, squeeze=False)
    for i, (imgs, _label) in enumerate(images_dicts):
        for j, (kind, img) in enumerate(imgs.items()):
            ax[i][j].imshow(img.squeeze(), cmap="gray", vmin=0, vmax=1)
            ax[i][j].set_title(kind)

# Preview the first five samples of one loaded model. "gmm_linear_heads" only
# exists in the CelebA config, so use the first model actually loaded.
preview_model = next(iter(experiments_images))
plot_exp_images(experiments_images[preview_model][:5], figsize=(10, 5))

In [None]:
# skimage-based image metrics for every (images_dict, label) pair of every model.
experiments_metrics = {}
for model_name, img_dicts in experiments_images.items():
    experiments_metrics[model_name] = [
        {"metrics": ev.images_metrics(img_dict), "label": label}
        for img_dict, label in img_dicts
    ]

In [None]:
# Long-format table: one row per metrics dict returned by ev.images_metrics
# (presumably one per image kind, cf. the `img_kind` column used below),
# tagged with the model and the sample label.
per_image_rows = []
for model, label_metrics in experiments_metrics.items():
    for label_metric in label_metrics:
        for metric in label_metric["metrics"]:
            per_image_rows.append(
                {"imputer_model": model, "label": label_metric["label"], **metric}
            )
per_image_metrics_df = pd.DataFrame(per_image_rows)
per_image_metrics_df

In [None]:
# Per (model, image kind) summary of the per-image metrics as "mean ± std".
# ddof=1 (sample std) matches pandas' "std" used in the NLL/MSE summary above;
# np.std's default ddof=0 made the two reported tables inconsistent.
pimdf = (
    per_image_metrics_df
    .drop(columns="label")
    .groupby(["imputer_model", "img_kind"])
    .agg(lambda pts: "{0:.2f} ± {1:.2f}".format(np.mean(pts), np.std(pts, ddof=1)))
    .reset_index()
)
# pimdf.to_csv("celeba_ssim_psnr.csv")

pimdf[pimdf.img_kind == "inpainted_means_0"]

In [None]:
# Box plots of the skimage metrics per image kind, coloured by model.
for metric_name in ("structural_similarity", "peak_signal_noise_ratio"):
    fig, metric_ax = plt.subplots(figsize=(15, 5))
    sns.boxplot(
        data=per_image_metrics_df,
        x="img_kind",
        y=metric_name,
        hue="imputer_model",
        ax=metric_ax,
    )
    metric_ax.set_title(metric_name)
    plt.show()

# Example inpaintings of the same images by different models

In [None]:
# Models compared in the mosaic below, mapped to their display names.
exp_to_name = {
    "torch_mfa": "MFA",
    "dmfa_incomp": "DMFA",
}

# Fail with a readable message instead of a bare KeyError when a model in
# exp_to_name was not loaded from experiments_paths.
missing_models = set(exp_to_name) - set(experiments_images)
assert not missing_models, f"models not loaded (check experiments_paths): {sorted(missing_models)}"

experiments_images_rnd = {k: experiments_images[k] for k in exp_to_name}

experiments_images_rnd.keys()

In [None]:
# Mosaic: original image, masked input, then one column per model with its
# inpainted means; the missing region is outlined in red.
n_rows = 8
n_cols = len(experiments_images_rnd) + 2


def _missing_region_patch(mask):
    """Return a red Rectangle around the pixels where mask[:, :, 0] != 1, or None.

    In ``imshow`` pixel (row y, col x) covers [x - 0.5, x + 0.5] x
    [y - 0.5, y + 0.5], so the box starts half a pixel before the first
    missing pixel and spans ``max - min + 1`` pixels in each direction.
    """
    ys, xs = np.nonzero(mask[:, :, 0] != 1)
    if xs.size == 0:  # nothing missing -> nothing to outline
        return None
    # min/max rather than first/last nonzero: nonzero is row-major ordered,
    # so xs[0]/xs[-1] are only the column extremes for rectangular masks.
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    return patches.Rectangle(
        (x0 - 0.5, y0 - 0.5), x1 - x0 + 1, y1 - y0 + 1,
        linewidth=2, edgecolor="r", facecolor="none",
    )


fig, ax = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(n_cols * 2, n_rows * 2),
    squeeze=False,
)
for i, (exp_name, imgs) in enumerate(experiments_images_rnd.items()):
    # The original / masked columns are identical for every model, so they
    # are drawn from the first model only.
    if i == 0:
        for c, img_name, ttl in [
            (0, "original", "Original image"),
            (1, "masked", "Image with\nmissing data"),
        ]:
            ax[0, c].set_title(ttl)
            for j in range(n_rows):
                ax[j, c].imshow(imgs[j][0][img_name].squeeze(), vmin=0, vmax=1, cmap="gray")
                ax[j, c].axis("off")

    c = i + 2
    ax[0, c].set_title(f"{exp_to_name[exp_name]} - inpainted")
    for j in range(n_rows):
        sample_images = imgs[j][0]
        ax[j, c].imshow(sample_images["inpainted_means_0"].squeeze(), vmin=0, vmax=1, cmap="gray")
        rect = _missing_region_patch(sample_images["mask"])
        if rect is not None:
            ax[j, c].add_patch(rect)
        ax[j, c].axis("off")

# fig.savefig("celeba_mosaic.png")

In [None]:
# Debug aid: coordinates of the missing (mask != 1) pixels of one sample.
# `nonzero` returns (row indices, column indices), i.e. (ys, xs); the old
# line unpacked them as (xs, ys), swapping the axes. It also relied on the
# `imgs` / `j` loop variables leaked from the previous cell, so the sample is
# now picked explicitly.
debug_mask = next(iter(experiments_images_rnd.values()))[0][0]["mask"]
ys, xs = (debug_mask[:, :, 0] - 1).nonzero()

# More visualizations

In [None]:
# Paper figure: each row shows a sample of our model next to the torch_mfa
# baseline, drawn by vis.visualize_sample_for_paper.
# "gmm_linear_heads" only exists in the CelebA config (KeyError for SVHN);
# the SVHN DMFA run is presumably the intended "ours" -- TODO confirm.
ours_model = "dmfa_incomp"
baseline_model = "torch_mfa"
# The figure was previously written to "celeba_mosaic.png" although the
# SVHN experiments are the ones loaded.
mosaic_path = "svhn_mosaic.png"

n_rows = min(20, len(experiments_results[ours_model]), len(experiments_images[baseline_model]))
n_cols = 10
fig, ax = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(n_cols * 2, n_rows * 2),
    squeeze=False,
)

for i in range(n_rows):
    vis.visualize_sample_for_paper(
        experiments_results[ours_model][i],
        experiments_images[baseline_model][i][0],
        ax_row=ax[i],
    )

fig.savefig(mosaic_path)

In [None]:
# The first n validation records of every model, drawn by vis.visualize_n_samples.
n = 20
row_len = 15
for model, metrics_with_cases in ml_metrics.items():
    # Slice by n (previously a hard-coded 20 that ignored changes to n).
    mwc = metrics_with_cases[:n]
    # squeeze=False keeps `ax` 2-D so ax[i] is a row even when n == 1.
    fig, ax = plt.subplots(
        nrows=n, ncols=row_len, figsize=(2 * row_len, n * 2), squeeze=False
    )
    for i, (m, c) in enumerate(mwc):
        vis.visualize_n_samples(*c, ax_row=ax[i])
    ax[0][0].set_title(model)
    plt.show()
    plt.close(fig)  # many large figures otherwise accumulate in memory


# Fréchet distance

In [None]:
# Train the MNIST classifier used as the feature extractor for the Fréchet
# distance below.
# NOTE(review): the experiments loaded above are SVHN while this classifier is
# trained on MNIST -- confirm it is the intended feature extractor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

torch.manual_seed(0)  # reproducible initialisation and shuffling

data_root = Path.home() / "uj" / "data"
n_epochs = 5

ds_train = MNIST(data_root, train=True, download=True, transform=ToTensor())
ds_val = MNIST(data_root, train=False, download=True, transform=ToTensor())

dl_train = DataLoader(ds_train, 1024, shuffle=True)
dl_val = DataLoader(ds_val, 256, shuffle=False)

classifier = MNISTNet()
opt = Adam(classifier.parameters(), 4e-3)
loss_fn = CrossEntropyLoss()

classifier.to(device)
for i in range(n_epochs):
    classifier.train()
    for X, y in tqdm(dl_train):
        opt.zero_grad()
        X, y = X.to(device), y.to(device)
        # The classifier returns a pair; the first element is used as class
        # scores, the second is ignored here.
        y_pred, _ = classifier(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        opt.step()

    # Validation accuracy over all samples. Averaging per-batch accuracies
    # would over-weight a smaller last batch.
    classifier.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():  # no autograd graph needed for evaluation
        for X, y in dl_val:
            X, y = X.to(device), y.to(device)
            y_pred, _ = classifier(X)
            n_correct += (y_pred.argmax(1) == y).sum().item()
            n_seen += y.numel()
    print(i, n_correct / n_seen)

In [None]:
def images_loader(
    images,
    batch_size = 256
):
    """Wrap (H, W, C) images in a non-shuffling DataLoader of float32 CHW tensors.

    Args:
        images: iterable of arrays shaped (H, W, C); each is transposed to
            (C, H, W).
        batch_size: DataLoader batch size.

    Returns:
        A DataLoader over a single-tensor TensorDataset, so every batch is a
        one-element list ``[tensor of shape (B, C, H, W)]``.
    """
    chw_images = [np.asarray(img).transpose(2, 0, 1) for img in images]
    if chw_images:
        # Stack into one array before converting: torch.Tensor(list_of_ndarrays)
        # converts element by element and is very slow.
        data = torch.as_tensor(np.stack(chw_images), dtype=torch.float32)
    else:
        data = torch.empty(0)  # same empty tensor as torch.Tensor([])
    return DataLoader(TensorDataset(data), batch_size=batch_size)
    

def frechet_distance(
    images_loader_1,
    images_loader_2,
    model=None,
    feature_dim=128,
):
    """Fréchet distance between the activation statistics of two image sets.

    Args:
        images_loader_1, images_loader_2: DataLoaders (e.g. from
            ``images_loader``); each whole dataset is passed through
            ``fid.calculate_activation_statistics``.
        model: feature extractor. ``None`` (default) uses the notebook's global
            ``classifier``, looked up at call time. The former default
            ``model=classifier`` was bound when this cell ran, so a classifier
            re-trained afterwards was silently ignored.
        feature_dim: feature size passed to
            ``fid.calculate_activation_statistics``.

    Returns:
        ``fid.calculate_frechet_distance`` of the two statistics pairs
        (presumably mean and covariance of the activations).
    """
    if model is None:
        model = classifier
    (mu_1, s_1), (mu_2, s_2) = [
        fid.calculate_activation_statistics(
            il,
            len(il.dataset),
            model,
            feature_dim=feature_dim,
        )
        for il in (images_loader_1, images_loader_2)
    ]
    return fid.calculate_frechet_distance(mu_1, s_1, mu_2, s_2)

def grouped_by_kinds(images_dicts):
    """Regroup ``(images_dict, label)`` pairs into ``kind -> [image, ...]``.

    The kinds are taken from the first pair's dict; images keep the order of
    ``images_dicts``. Labels are dropped. Empty input yields ``{}`` (it used
    to raise IndexError).
    """
    if not images_dicts:
        return {}
    first_dict = images_dicts[0][0]
    return {
        kind: [img_dict[kind] for img_dict, _label in images_dicts]
        for kind in first_dict
    }
    


In [None]:
# For every model: image kind -> list of that kind's images across samples.
experiments_images_by_kinds = {}
for model_name, img_dicts in experiments_images.items():
    experiments_images_by_kinds[model_name] = grouped_by_kinds(img_dicts)

In [None]:
# Fréchet distance between every image kind and the original images, per
# model. The loader over the originals is built once per model (it was rebuilt
# for every kind); it does not shuffle, so it can be iterated repeatedly.
frechet_dists = {}
for name, kind_to_images in tqdm(experiments_images_by_kinds.items()):
    originals_loader = images_loader(kind_to_images["original"])
    frechet_dists[name] = {
        kind: frechet_distance(originals_loader, images_loader(images))
        for kind, images in kind_to_images.items()
    }

In [None]:
# Long-format table of Fréchet distances: one row per (model, image kind).
fd_rows = []
for name, kind_to_fd in frechet_dists.items():
    fd_rows.extend(
        {"imputer_model": name, "kind": kind, "frechet_distance": fd}
        for kind, fd in kind_to_fd.items()
    )
frechet_dists_df = pd.DataFrame(fd_rows)

frechet_dists_df

In [None]:
# Bar chart of Fréchet distances per image kind, coloured by model.
fig, fd_ax = plt.subplots(figsize=(15, 5))
sns.barplot(
    data=frechet_dists_df,
    x="kind",
    y="frechet_distance",
    hue="imputer_model",
    ax=fd_ax,
)

In [None]:
# Persist the Fréchet-distance table.
frechet_dists_df.to_csv(path_or_buf="frechet_dists_tmp.csv")