In [None]:
# Re-import edited modules (e.g. the local `inpainting` package) automatically
# before every cell execution, so library edits don't require a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the project package (`inpainting`, one directory up from the notebook)
# importable. Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import distributions as dist
from torch.utils.data import DataLoader, TensorDataset
from torch import optim

from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms as tr
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from pprint import pprint
from inpainting.custom_layers import Reshape
from inpainting.inpainters.mnist import MNISTLinearInpainter, MNISTConvolutionalInpainter
from pathlib import Path
import pickle

from inpainting.losses import (
    _nll_masked_batch_loss,
    _batch_loss_fn, 
    _nll_masked_sample_loss_v2, 
    _nll_masked_ubervectorized_batch_loss_v1, 
    nll_masked_batch_loss_same_size_masks, 
    nll_masked_batch_loss,
    _nll_masked_sample_loss_v0,
)
from collections import defaultdict
import time

In [None]:
from inpainting.datasets.mnist import train_val_datasets
from inpainting.visualizations.digits import digit_with_mask as vis_digit_mask
from inpainting.training import train_inpainter
from inpainting.utils import classifier_experiment, inpainted
import inpainting.visualizations.samples as vis
from inpainting.datasets import mask_coding as mc
from inpainting.datasets.utils import RandomRectangleMaskConfig
import pandas as pd
import seaborn as sns

In [None]:
import matplotlib

# Force an opaque white figure background so saved/exported figures stay
# readable regardless of the notebook theme.
matplotlib.rcParams.update({"figure.facecolor": "white"})

In [None]:
# Debug helper: lists processes whose `ps aux` line matches "mprzewie"
# (presumably to spot the author's stray kernels/jobs). No visible cell uses its
# output. NOTE(review): consider removing before sharing — it exposes a username.
!ps aux | grep mprzewie

In [None]:
# Environment check: which GPUs this kernel is allowed to see (CUDA_VISIBLE_DEVICES)
# and their current status as reported by nvidia-smi.
!echo $CUDA_VISIBLE_DEVICES
!nvidia-smi

In [None]:
# Run on the current CUDA device when one is available, otherwise on the CPU.
# To force a CPU run, uncomment the override below.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")
device

In [None]:
# Output directory for this experiment's artefacts (figures etc.).
experiment_path = Path("../results/mnist/missing_data_10x10")
# parents=True: on a fresh checkout "../results/mnist" may not exist yet, and
# mkdir without it would raise FileNotFoundError.
experiment_path.mkdir(parents=True, exist_ok=True)

In [None]:
import os

# NOTE(review): the default is an absolute, user-specific path. On other machines
# point the MNIST_DATA_DIR environment variable at the dataset root instead.
mnist_data_dir = os.environ.get("MNIST_DATA_DIR", "/home/mprzewiezlikowski/uj/data/")

ds_train, ds_val = train_val_datasets(
    mnist_data_dir,
    mask_configs=[
        # Presumably a 10x10 rectangular mask (matches the `missing_data_10x10`
        # experiment name) — confirm argument meaning in RandomRectangleMaskConfig.
        RandomRectangleMaskConfig(mc.UNKNOWN_LOSS, 10, 10, 0, 0),
        # RandomRectangleMaskConfig(mc.UNKNOWN_NO_LOSS, 10, 10, 0, 0)
    ],
)

# Preview the first 100 training samples with their masks on a 10x10 grid,
# titled by label, and save the grid next to the other experiment artefacts.
fig, axes = plt.subplots(10, 10, figsize=(15, 15))
for i, ax in enumerate(axes.flat):  # row-major, same order as axes[i // 10, i % 10]
    (x, j), y = ds_train[i]
    ax.set_title(f"{y}")
    vis_digit_mask(x, j, ax)
# Save the figure created above explicitly rather than whatever plt.gcf() returns,
# which would be wrong if the plotting helper opened another figure.
train_fig = fig
train_fig.savefig(experiment_path / "train.png")
plt.show()

In [None]:
# Mini-batch size and shuffling shared by the train and validation loaders.
batch_size = 96
loader_kwargs = dict(batch_size=batch_size, shuffle=True)
dl_train = DataLoader(ds_train, **loader_kwargs)
dl_val = DataLoader(ds_val, **loader_kwargs)

In [None]:
def m_std(x, j, p, m, a, d):
    """Diagnostic with the loss signature: mean of `m.std(dim=0)`."""
    return m.std(dim=0).mean()


def a_max(x, j, p, m, a, d):
    """Diagnostic with the loss signature: largest entry of `a`."""
    return a.max()

In [None]:
# Loss implementations to time against each other, keyed by version label.
# Each value is called as loss(x, j, p, m, a, d) by the benchmark loop; the
# per-sample variants (v0, v2) are lifted to batch losses via _batch_loss_fn.
losses_to_benchmark = dict(
    v0=_batch_loss_fn(_nll_masked_sample_loss_v0),
    v1=_nll_masked_batch_loss,
    v2=_batch_loss_fn(_nll_masked_sample_loss_v2),
    v3=_nll_masked_ubervectorized_batch_loss_v1,
    v4=nll_masked_batch_loss_same_size_masks,
    v5=nll_masked_batch_loss,
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
# device = torch.device("cpu")
history = []
# inpainter = MNISTLinearInpainter(n_mixes=1, hidden_size=2048)
inpainter = MNISTConvolutionalInpainter(n_mixes=1)

opt = optim.Adam(inpainter.parameters(), lr=4e-5, weight_decay=0)
n_epochs = 50
# Number of training batches every loss implementation is timed on.
n_benchmark_batches = 11
benchmark_results = defaultdict(list)  # loss name -> per-batch wall-clock times [s]
l_values = defaultdict(list)           # loss name -> per-batch loss outputs (detached)

inpainter = inpainter.to(device)
inpainter.eval()


def _wait_for_device():
    """Block until all queued CUDA work has finished; no-op on CPU.

    CUDA kernel launches are asynchronous, so without synchronising the
    measured interval would mostly be launch overhead, not the loss computation.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


# NOTE(review): the first batch's timings may include one-off warm-up costs
# (CUDA context / kernel caching); consider discarding it when comparing.
for i, ((x, j), y) in tqdm(enumerate(dl_train), total=n_benchmark_batches):
    x, j, y = [t.to(device) for t in [x, j, y]]
    p, m, a, d = inpainter(x, j)

    for loss_name, l in losses_to_benchmark.items():
        # Ensure the forward pass / previous loss has finished before starting the clock.
        _wait_for_device()
        s = time.perf_counter()

        v = l(x, j, p, m, a, d)
        _wait_for_device()
        e = time.perf_counter()

        benchmark_results[loss_name].append(e - s)
        # Detach so stored values don't keep each batch's autograd graph
        # (and the inpainter activations it references) alive.
        l_values[loss_name].append(v.detach())

    if i + 1 >= n_benchmark_batches:
        break

In [None]:
# One column per loss implementation, one row per benchmarked batch (seconds).
# NOTE(review): despite the "_cpu" suffix, these timings come from `device`,
# which is CUDA whenever a GPU is available.
df_bench_cpu = pd.DataFrame.from_dict(benchmark_results)

In [None]:
# Long format (one row per loss/batch timing) for seaborn.
bench_long = pd.melt(df_bench_cpu, var_name="loss_name", value_name="time")

# Distribution of per-batch wall-clock time for each loss implementation (log scale).
f, ax = plt.subplots(figsize=(7, 5))
ax.set(yscale="log")
sns.boxplot(x="loss_name", y="time", data=bench_long, ax=ax)
ax.set(
    title="Loss implementation benchmark",
    xlabel="loss implementation",
    ylabel="time per batch [s]",
)
plt.show()

In [None]:
# Loss values collected per implementation by the benchmark loop (one entry per
# benchmarked batch); presumably displayed to check that the variants agree.
l_values