In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that edits to
# imported project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import distributions as dist
from torch.utils.data import DataLoader, TensorDataset
from torch import optim

from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms as tr
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from pprint import pprint
from inpainting.custom_layers import Reshape
from inpainting.losses import nll_masked_batch_loss, r2_masked_batch_loss, r2_total_batch_loss, nll_masked_ubervectorized_batch_loss
from inpainting.inpainters.mnist import MNISTConvolutionalInpainter
from pathlib import Path

In [None]:
from inpainting.datasets.mnist import train_val_datasets
from inpainting.visualizations.digits import digit_with_mask as vis_digit_mask
from inpainting.training import train_inpainter
from inpainting.utils import classifier_experiment, inpainted

In [None]:
import matplotlib
# Give every figure a white background (applies to all figures created below).
matplotlib.rcParams.update({"figure.facecolor": "white"})

In [None]:
# Shell escape: list running processes whose `ps` line contains "mprzewie".
!ps aux | grep mprzewie

In [None]:
# Shell escapes: print the CUDA_VISIBLE_DEVICES environment variable, then run
# nvidia-smi.
!echo $CUDA_VISIBLE_DEVICES
!nvidia-smi

In [None]:
# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")
device

In [None]:
# Directory that receives the artifacts this notebook writes (figures, schema).
experiment_path = Path("../results/mnist_convolutional")
# Create it up front so the later `savefig` / `open("w")` calls do not fail on
# a fresh checkout; `exist_ok=True` keeps this cell safe to re-run.
experiment_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the train/validation datasets, then preview the first 100 training
# items in a 10x10 grid (each panel titled with its label) and save the grid.
ds_train, ds_val = train_val_datasets("/home/mprzewiezlikowski/uj/data/")

fig, axes = plt.subplots(10, 10, figsize=(15, 15))
# `axes.flat` walks the grid row by row, so panel i sits at (i // 10, i % 10).
for i, ax in enumerate(axes.flat):
    # Each dataset item unpacks as ((x, j), y); presumably image, mask and
    # label -- confirm against `train_val_datasets`.
    (x, j), y = ds_train[i]
    ax.set_title(f"{y}")
    vis_digit_mask(x, j, ax)

train_fig = plt.gcf()
train_fig.savefig(experiment_path / "train.png")
plt.show()

In [None]:
# classifier = MLPClassifier((100, 200, 10,), learning_rate_init=4e-3, max_iter=1000).fit(ds_train.X.reshape(-1, 64), ds_train.y)

In [None]:
# Wrap both datasets in shuffling DataLoaders that share one batch size.
batch_size = 12
dl_train, dl_val = (
    DataLoader(split, batch_size, shuffle=True) for split in (ds_train, ds_val)
)

In [None]:
def m_std(x, j, p, m, a, d):
    """Return the mean, over all remaining entries, of ``m``'s std along dim 0.

    Only ``m`` is used. The other arguments are accepted and ignored;
    presumably they keep the signature uniform with the other loss/metric
    callables passed to ``train_inpainter`` -- confirm against that function.

    Args:
        x, j, p, a, d: unused.
        m: tensor; its standard deviation is taken along ``dim=0``.

    Returns:
        A scalar tensor: ``m.std(dim=0).mean()``.
    """
    return m.std(dim=0).mean()

In [None]:
# Build the inpainter (n_mixes=1) and its Adam optimiser, then train it and
# keep the returned per-epoch history.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

n_epochs = 50
inpainter = MNISTConvolutionalInpainter(n_mixes=1)
opt = optim.Adam(params=inpainter.parameters(), weight_decay=0, lr=4e-3)

history = train_inpainter(
    inpainter,
    dl_train,
    dl_val,
    opt,
    n_epochs=n_epochs,
    device=device,
    loss_fn=nll_masked_ubervectorized_batch_loss,
    losses_to_log=None,
    # tqdm_loader=True
)

In [None]:
# Write the model's printed representation to `inpainter.schema`.
schema_path = experiment_path / "inpainter.schema"
with schema_path.open("w") as f:
    f.write(str(inpainter) + "\n")

# NOTE(review): if the line below is re-enabled, call `state_dict()`; as
# written it would pass the bound method rather than the weights.
# torch.save(inpainter.state_dict, experiment_path / "inpainter.state")

In [None]:
# Collect the "losses" entry of every epoch record; the trailing semicolon
# suppresses the cell output.
[epoch_record["losses"] for epoch_record in history];

In [None]:
# Display the "objective" loss entry recorded for the last epoch.
last_record = history[-1]
last_record["losses"]["objective"]

In [None]:
# Keep a second reference to the training history under `history_tmp`;
# `history` itself is left bound to the same object.
history_tmp = history

In [None]:
# Plot the train/val curve of every logged loss and save one PNG per loss.
# Sorting the names makes the figure order the same on every run; iterating a
# bare set of strings does not guarantee that.
epochs = range(len(history))
for loss_name in sorted(history[0]["losses"]):
    # One explicit figure per loss. Relying on the implicit current figure
    # would stack every loss onto the same axes on backends where
    # `plt.show()` does not close the figure.
    fig, ax = plt.subplots()
    for fold in ["train", "val"]:
        ax.plot(
            epochs,
            [h["losses"][loss_name][fold] for h in history],
            label=fold,
        )
    ax.set_title(loss_name)
    ax.set_xlabel("epoch")
    ax.legend()
    fig.savefig(experiment_path / f"history.{loss_name}.png")
    plt.show()
    # Release the figure in case the backend's `show` left it open.
    plt.close(fig)

In [None]:


# Overview figure: for a subset of epochs, draw two rows (train, then val)
# showing the first sample of that epoch's `sample_results`:
#   col 0 - ground truth, col 1 - masked input, col 2 - inpainted image,
#   cols 3.. - one panel per entry of `m`.
skip = 15

# Epochs to render: every `skip`-th epoch plus the final one. Rows are indexed
# by position in this list. Indexing by `e // skip` would place the final
# epoch on the same rows as the previously rendered epoch whenever the final
# epoch index is not a multiple of `skip`, drawing over it.
last_epoch = len(history) - 1
shown_epochs = [
    e for e in range(len(history))
    if e % skip == 0 or e == last_epoch
]

fig, axes = plt.subplots(
    len(shown_epochs) * 2,
    3 + inpainter.n_mixes,
    figsize=(10, 15)
)

for block_no, e in enumerate(shown_epochs):
    h = history[e]

    for ax_no, fold in [(0, "train"), (1, "val")]:

        # Take the first element of every array stored for this fold.
        x, j, p, m, a, d, y = [t[0] for t in h["sample_results"][fold]]
        row_no = block_no * 2 + ax_no

        # 0 - gt
        ax = axes[row_no, 0]
        ax.imshow(x.reshape(28, 28), cmap="gray")
        ax.axis("off")
        ax.set_title(f"{e} {fold} y_gt = {y}")

        # 1 - masked
        ax = axes[row_no, 1]
        vis_digit_mask(x, j, ax)
#         y_masked_pred = classifier.predict(x.reshape(1, 64))[0]
#         ax.set_title(f"y_m = {y_masked_pred}")

        # 2 - inpainted
        ax = axes[row_no, 2]
        x_inp = x.copy()

        # Pick one entry of `m` at random, using `p` as selection probabilities.
        m_ind = np.random.choice(np.arange(m.shape[0]), p=p)
        m_inp = m[m_ind].reshape(x.shape)

        # Overwrite only the positions where `j == 0` with the chosen entry.
        x_inp[j == 0] = m_inp[j == 0]
#         y_inp_pred = classifier.predict(x_inp.reshape(1, 64))[0]

        ax.imshow(x_inp.reshape(28, 28), cmap="gray", vmin=0, vmax=1)
        ax.axis("off")
#         ax.set_title(f"y_inp = {y_inp_pred}")

        # 3.. - every entry of `m`, titled with its probability truncated to
        # two decimals; the sampled entry is marked "chosen".
        for i, m_ in enumerate(m):
            ax = axes[row_no, 3 + i]

            ax.imshow(m_.reshape(28, 28), cmap="gray", vmin=0, vmax=1)
            ax.axis("off")
            p_form = int(p[i] * 100) / 100
            chosen = "chosen " if i == m_ind else ""
            ax.set_title(chosen + f"M_{i}, p={p_form}")

# Save through the explicit figure handle rather than the pyplot current figure.
epochs_fig = fig
epochs_fig.savefig(experiment_path / "epochs_renders.png")

In [None]:
# Per-epoch detail figures: for every rendered epoch and each fold, draw up to
# `n_rows` samples (one per row) and save the figure as `<epoch>_<fold>.png`:
#   col 0 - ground truth, col 1 - masked input, col 2 - inpainted image,
#   cols 3.. - one panel per entry of `m`.
epochs_path = experiment_path / "epochs"
# `exist_ok=True` keeps the cell re-runnable; `parents=True` covers a results
# directory that does not exist yet.
epochs_path.mkdir(parents=True, exist_ok=True)

skip = 5
n_rows = 16


for e, h in enumerate(history):
    # Render every `skip`-th epoch plus the final one.
    if e % skip != 0 and e != (len(history) - 1):
        continue

    for fold in ["train", "val"]:

        X, J, P, M, A, D, Y = h["sample_results"][fold]
        fig, axes = plt.subplots(
            n_rows,
            3 + inpainter.n_mixes,
            figsize=(10, 15)
        )

        rows_drawn = 0
        # Zipping with `range(n_rows)` stops after `n_rows` samples instead of
        # iterating over (and skipping) the remainder.
        for row_no, x, j, p, m, a, d, y in zip(range(n_rows), X, J, P, M, A, D, Y):
            rows_drawn = row_no + 1

            # 0 - gt
            ax = axes[row_no, 0]
            ax.imshow(x.reshape(28, 28), cmap="gray")
            ax.axis("off")
#             ax.set_title(f"y_gt = {y}")

            # 1 - masked
            ax = axes[row_no, 1]
            vis_digit_mask(x, j, ax)
#             y_masked_pred = classifier.predict(x.reshape(-1, 64))[0]
#             ax.set_title(f"y_m = {y_masked_pred}")

            # 2 - inpainted: pick one entry of `m` at random (with `p` as
            # selection probabilities) and copy it into positions where `j == 0`.
            ax = axes[row_no, 2]
            x_inp = x.copy()
            m_ind = np.random.choice(np.arange(m.shape[0]), p=p)
            m_inp = m[m_ind].reshape(x.shape)
            x_inp[j == 0] = m_inp[j == 0]
#             y_inp_pred = classifier.predict(x_inp.reshape(-1, 64))[0]

            ax.imshow(x_inp.reshape(28, 28), cmap="gray")
            ax.axis("off")
#             ax.set_title(f"y_inp = {y_inp_pred}")
            # 3 - M

            for i, m_ in enumerate(m):
                ax = axes[row_no, 3 + i]

                ax.imshow(m_.reshape(28, 28), cmap="gray", vmin=0, vmax=1)
                ax.axis("off")
                p_form = int(p[i] * 100) / 100
                chosen = "chosen " if i == m_ind else ""
                ax.set_title(chosen + f"M_{i}, p={p_form}")

        # Blank the rows left over when fewer than `n_rows` samples were
        # available, so they do not show as empty ticked axes.
        for ax in axes[rows_drawn:].ravel():
            ax.axis("off")

        title = f"{e}_{fold}"
        fig.suptitle(title)
        fig.savefig(epochs_path / f"{title}.png")
        # Close each figure once it is on disk so open figures do not pile up
        # across the loop. As a consequence they are not displayed inline;
        # the saved PNGs are the output of this cell.
        plt.close(fig)
#         plt.show()


# epochs_fig = plt.gcf()
# epochs_fig.savefig(experiment_path / "epochs_renders.png")