In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("./../..")

In [None]:
import numpy as np
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision
import torch.nn as nn
from torch import optim
#
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm
#

from misc.plot_utils import plot_mat, imshow
from effcn.layers import FCCaps, FCCapsWOBias, Squash
from misc.utils import count_parameters
from effcn.functions import margin_loss
from datasets import AffNIST
#
from einops import rearrange, repeat
from torch import einsum, nn
#
import helpers
#
# local imports
from datasets import AffNIST
from effcn.layers import Squash
from effcn.functions import margin_loss, max_norm_masking
from misc.utils import count_parameters
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat, plot_mat2
from misc.metrics import *

In [None]:
import sys
sys.path.append("./../..")

# standard lib
import shutil
from pathlib import Path

# external imports
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import scipy as sp
import pandas as pd
pd.options.display.float_format = '{:,.2f}'.format
import pickle
from torch.utils.data import DataLoader

# local imports
from datasets.csprites import ClassificationDataset
from effcn.layers import Squash
from effcn.functions import margin_loss, max_norm_masking
from misc.utils import count_parameters
from misc.plot_utils import plot_couplings, plot_capsules, plot_mat, plot_mat2
from misc.metrics import *
from misc.utils import normalize_transform, inverse_normalize_transform

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torch.nn as nn

In [None]:
# Prefer the second GPU (the original setup) but fall back to the first one when
# only a single GPU is visible ("cuda:1" would not exist there), and to the CPU
# when CUDA is unavailable.
if torch.cuda.is_available():
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)

In [None]:
def pos_tanh_embedding(h, w, t_freq=2, t_symm=0.5, scale=True):
    """Four-channel positional encoding made of smooth tanh ramps.

    For positions ``p`` evenly spaced in [0, 1] each ramp is
    ``(1 - tanh(t_freq * (p - t_symm))) / 2``. Channels 0 and 1 vary along the
    height axis (forward / reversed positions); channels 2 and 3 vary along the
    width axis.

    Args:
        h: number of rows.
        w: number of columns.
        t_freq: steepness of the ramps.
        t_symm: position of the ramp midpoint within [0, 1].
        scale: if True, jointly min-max rescale all channels to [0, 1].

    Returns:
        Tensor of shape (4, h, w).
    """
    def ramp(p):
        return (1 - torch.tanh(t_freq * (p - t_symm))) * 0.5

    # Each 1-D ramp is broadcast over the other axis. Rows are indexed with `h`
    # and columns with `w`, so non-square grids work. Building (w, h) planes
    # only fits a (h, w) tensor when h == w.
    pe = torch.stack([
        ramp(torch.linspace(0, 1, h)).unsqueeze(1).expand(h, w),
        ramp(torch.linspace(1, 0, h)).unsqueeze(1).expand(h, w),
        ramp(torch.linspace(0, 1, w)).unsqueeze(0).expand(h, w),
        ramp(torch.linspace(1, 0, w)).unsqueeze(0).expand(h, w),
    ])
    if scale:
        lo = pe.min()
        span = pe.max() - lo
        # Guard against 0/0 (e.g. t_freq=0 yields a constant encoding).
        pe = (pe - lo) / span if span > 0 else pe - lo
    return pe

In [None]:
pe = pos_tanh_embedding(h=28, w=28)  # encoding for a 28 x 28 grid

In [None]:
(pe.shape, pe.min(), pe.max())  # shape and value range of the encoding

In [None]:
# Toy batch: random images with the 4 positional channels appended.
b, h, w, c = 2, 28, 28, 3
x = torch.rand(b, c, h, w)
pe = pos_tanh_embedding(h, w)[None].repeat(b, 1, 1, 1)  # (b, 4, h, w)
x = torch.cat([x, pe], dim=1)
print(x.shape)

In [None]:
pe.flatten(start_dim=2).transpose(1, 2).shape  # (b, h*w, 4): one encoding per pixel

In [None]:
from math import log, pi
from einops import rearrange, repeat
from torch import einsum, nn

def fourier_encode(x, max_freq, num_bands=4, base=2):
    """Expand every element of ``x`` into sin/cos features plus itself.

    A trailing axis of size ``2 * num_bands + 1`` is added. Along it the
    output holds ``sin(x * s_k * pi)`` for ``num_bands`` scales ``s_k``, spaced
    geometrically from ``base`` to ``max_freq / 2``, then the matching
    cosines, then the raw value ``x``.
    """
    x = x.unsqueeze(-1)
    last_exponent = log(max_freq / 2) / log(base)
    scales = torch.logspace(1., last_exponent, num_bands, base=base,
                            device=x.device, dtype=x.dtype)
    # reshape to (1, ..., 1, num_bands) so the scales broadcast against x
    scales = scales.view(*([1] * (x.dim() - 1)), num_bands)
    angles = x * scales * pi
    return torch.cat([angles.sin(), angles.cos(), x], dim=-1)


def pos_embedding_fourier(data, max_freq=10, num_bands=2, base=2):
    """Fourier positional encoding for every grid position of `data`, per batch item.

    ``data`` is read only for its shape ``(b, *axis, channels)`` and its device.
    The positions along each spatial axis are spread evenly over [-1, 1] and
    passed through `fourier_encode`.

    Returns:
        Tensor of shape ``(b, *axis, len(axis) * (2 * num_bands + 1))``.
    """
    b, *axis, _ = data.shape
    # Build the grid on the device of `data` so the encoding can be
    # concatenated with it directly (it was always created on the CPU before).
    axis_pos = [torch.linspace(-1., 1., steps=size, device=data.device) for size in axis]
    pos = torch.stack(torch.meshgrid(*axis_pos), dim=-1)
    enc_pos = fourier_encode(
        pos,
        max_freq,
        num_bands,
        base)
    enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
    enc_pos = repeat(enc_pos, '... -> b ...', b=b)
    return enc_pos

def pos_fourier_embedding(h, w, max_freq=10, num_bands=2, base=2):
    """Fourier positional encoding of an h x w grid (no batch axis).

    The row and column coordinates are each spread over [-1, 1], encoded with
    `fourier_encode`, and concatenated along the last axis.

    Returns:
        Tensor of shape (h, w, 2 * (2 * num_bands + 1)).
    """
    rows = torch.linspace(-1., 1., steps=h)
    cols = torch.linspace(-1., 1., steps=w)
    grid = torch.stack(torch.meshgrid(rows, cols), dim=-1)        # (h, w, 2)
    encoded = fourier_encode(grid, max_freq, num_bands, base)      # (h, w, 2, 2*num_bands+1)
    return encoded.flatten(start_dim=-2)


def pos_linear_embedding(h, w):
    """Four-channel linear positional encoding with values in [0, 1].

    Channels 0 and 1 ramp along the height axis (0 -> 1 and 1 -> 0). Channels
    2 and 3 ramp along the width axis in the same way.

    Returns:
        Tensor of shape (4, h, w).
    """
    # Rows are indexed with `h` and columns with `w`. Building (w, h) planes
    # only fits a (h, w) tensor for square grids.
    return torch.stack([
        torch.linspace(0, 1, h).unsqueeze(1).expand(h, w),
        torch.linspace(1, 0, h).unsqueeze(1).expand(h, w),
        torch.linspace(0, 1, w).unsqueeze(0).expand(h, w),
        torch.linspace(1, 0, w).unsqueeze(0).expand(h, w),
    ])

In [None]:
def rand_mask_1_out(b, n_q):
    """Boolean (b, n_q) mask with exactly one randomly chosen False per row."""
    mask = torch.ones(b, n_q)
    for row in mask:  # each `row` is a view into `mask`
        row[np.random.randint(n_q)] = 0
    return mask.bool()

def rand_mask(b, d, p_masked=0.2):
    """Random boolean (b, d) mask; each entry is False with probability `p_masked`.

    Rows without any True entry are redrawn until every row keeps at least one
    True. Rows are independent, so redrawing only the failing rows samples the
    same conditional distribution as redrawing the whole batch. It also cannot
    exhaust the recursion limit for large `b` or high `p_masked`.

    Raises:
        ValueError: if a non-empty batch can never satisfy the constraint
            (``d < 1`` or ``p_masked >= 1``).
    """
    if b > 0 and (d < 1 or not p_masked < 1):
        raise ValueError(f"cannot keep an entry per row with d={d}, p_masked={p_masked}")
    mask = torch.rand(b, d) > p_masked
    empty_rows = ~mask.any(dim=1)
    while empty_rows.any():
        mask[empty_rows] = torch.rand(int(empty_rows.sum()), d) > p_masked
        empty_rows = ~mask.any(dim=1)
    return mask

def exists(val):
    """True for any value except None."""
    return not (val is None)

def default(val, d):
    """Return `val`, or the fallback `d` when `val` is None."""
    if val is None:
        return d
    return val

class Attention(nn.Module):
    """Multi-head cross-attention from query tokens to key/value tokens.

    Args:
        d_q: feature size of the query tokens ``x_q``.
        d_kv: feature size of the key/value tokens ``x_kv``.
        n_heads: number of attention heads.
        d_head: feature size per head (inner width ``n_heads * d_head``).
        dropout: dropout probability after the output projection.
        scale: multiplier for the attention logits; ``d_head ** -0.5`` if None.
        d_out: output feature size; ``d_q`` if None.
    """

    def __init__(self, d_q, d_kv, n_heads, d_head, dropout=0.0, scale=None, d_out=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_inner = d_head * n_heads
        self.d_out = d_q if d_out is None else d_out
        # Always set the scale. An explicit `scale` used to leave the
        # attribute undefined, so forward() raised AttributeError.
        self.scale = d_head ** -0.5 if scale is None else scale
        self.to_q = nn.Linear(d_q, self.d_inner, bias=False)
        self.to_kv = nn.Linear(d_kv, self.d_inner * 2, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(self.d_inner, self.d_out),
            nn.Dropout(dropout)
        )

    def _split_heads(self, t):
        """(b, n, h*d) -> (b*h, n, d), batch-major (same layout as einops '(b h)')."""
        b, n, _ = t.shape
        return (t.reshape(b, n, self.n_heads, -1)
                 .permute(0, 2, 1, 3)
                 .reshape(b * self.n_heads, n, -1))

    def _merge_heads(self, t):
        """(b*h, n, d) -> (b, n, h*d); inverse of `_split_heads`."""
        bh, n, d = t.shape
        b = bh // self.n_heads
        return (t.reshape(b, self.n_heads, n, d)
                 .permute(0, 2, 1, 3)
                 .reshape(b, n, self.n_heads * d))

    def forward(self, x_q, x_kv, mask=None):
        """
        Args:
            x_q: queries (b_q, n_q, d_q). b_q may be 1 to share the same
                queries across the key/value batch.
            x_kv: keys/values (b, n_kv, d_kv).
            mask: optional bool (b, n_kv); False marks key/value tokens to ignore.

        Returns:
            Tensor of shape (b, n_q, d_out).
        """
        h = self.n_heads
        b = x_kv.shape[0]
        if x_q.shape[0] == 1 and b > 1:
            # Broadcast shared (e.g. learned) queries explicitly. Implicit
            # einsum broadcasting only worked with a single head.
            x_q = x_q.expand(b, -1, -1)
        q = self._split_heads(self.to_q(x_q))
        k, v = (self._split_heads(t) for t in self.to_kv(x_kv).chunk(2, dim=-1))
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if mask is not None:
            mask = mask.reshape(mask.shape[0], -1).repeat_interleave(h, dim=0).unsqueeze(1)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        # Normalise over the keys (last axis). Normalising over the queries
        # (dim=1) made the key mask ineffective: a masked key column turned
        # into a uniform distribution instead of receiving zero weight.
        attn = sim.softmax(dim=-1)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        return self.to_out(self._merge_heads(out))

In [None]:
def mask_template(n, n_masked=None, p_masked=None):
    """Boolean vector of length `n` whose first `n_masked` entries are False.

    Exactly one of `n_masked` (a positive count) or `p_masked` (a positive
    fraction, truncated to ``int(p_masked * n)``) is expected.

    Raises:
        AssertionError: on missing/non-positive arguments, or if more
            positions are requested than exist.
    """
    if n_masked is None:
        assert p_masked is not None and p_masked > 0, "give n_masked or a positive p_masked"
        n_masked = int(p_masked * n)
    else:
        assert n_masked > 0, "n_masked must be positive"
    # Requesting more than n used to mask everything silently.
    assert n_masked <= n, f"cannot mask {n_masked} of {n} positions"
    mask_temp = torch.ones(n, dtype=torch.bool)
    mask_temp[:n_masked] = False
    return mask_temp

def batch_mask_generator(b, n, n_masked=None, p_masked=None):
    """Return a closure that draws a fresh random (b, n) boolean mask per call.

    Every row carries the same number of False entries (from `mask_template`).
    Their positions are shuffled independently per row on every call.
    """
    base = mask_template(n, n_masked, p_masked).unsqueeze(0).repeat(b, 1)
    row_ids = torch.arange(base.shape[0]).unsqueeze(-1)

    def _mask_generator():
        perm = torch.argsort(torch.rand(*base.shape), dim=-1)
        return base[row_ids, perm]

    return _mask_generator

def mask_generator(b, n, n_masked=None, p_masked=None):
    """Draw one random (b, n) boolean mask.

    One-shot counterpart of `batch_mask_generator`: every row has the same
    number of False entries at randomly permuted positions.
    """
    return batch_mask_generator(b, n, n_masked=n_masked, p_masked=p_masked)()

def masked_select(x, mask):
    """Keep the positions where `mask` is True while preserving the batch axis.

    Args:
        x: tensor of shape (b, n, d).
        mask: boolean tensor of shape (b, n), or (1, n) to broadcast over the
            batch. Every row must select the same number of positions.

    Returns:
        Tensor of shape (b, k, d), where k is the per-row count of True.
    """
    assert len(x.shape) == 1 + len(mask.shape), "mask must have one dimension fewer than x"
    b, _, d = x.shape
    counts = mask.sum(dim=-1)
    # torch.masked_select flattens its result. Reshaping back to (b, -1, d) is
    # only correct when every row keeps the same number of positions; otherwise
    # elements silently spill between batch rows.
    if counts.numel() > 1:
        assert bool((counts == counts.flatten()[0]).all()), \
            "every mask row must select the same number of positions"
    return torch.masked_select(x, mask.unsqueeze(-1)).reshape(b, -1, d)

In [None]:
# Toy sizes for a single-head cross-attention check
b, d_q, d_kv = 2, 3, 4
n_heads, d_head, dropout = 1, 6, 0.0
n_q, n_kv = 7, 8  # number of query / key-value tokens

model = Attention(d_q=d_q, d_kv=d_kv, n_heads=n_heads, d_head=d_head, dropout=dropout)

x_q = torch.rand(b, n_q, d_q)
x_kv = torch.rand(b, n_kv, d_kv)
y = model(x_q, x_kv)

In [None]:
getattr(model, "to_out")  # output projection + dropout

In [None]:
# Queries with batch size 1 against a batch of key/values
x_q, x_kv = torch.rand(1, n_q, d_q), torch.rand(b, n_kv, d_kv)
y = model(x_q, x_kv)

In [None]:
class GEGLU(nn.Module):
    """Gated GELU: split the last axis in half and gate the first half with GELU of the second."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.gelu(gate) * value
    
class CapsuleFeedForward(nn.Module):
    """MLP with separate weights for every capsule/token slot.

    Input and output have shape (b, n, d). Slot ``k`` is transformed by its own
    two-layer MLP ``d -> d * mult -> d`` with ReLU and dropout in between and
    no biases.
    """

    def __init__(self, n, d, dropout=0., mult=4):
        super().__init__()
        # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)], the bound nn.Linear
        # uses by default. The previous torch.rand init was all-positive and
        # unscaled, so activations grew roughly by a factor ~d/2 per layer.
        self.W1 = nn.Parameter(self._uniform(n, d, d * mult, fan_in=d))
        self.W2 = nn.Parameter(self._uniform(n, d * mult, d, fan_in=d * mult))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _uniform(*shape, fan_in):
        """Tensor of `shape` drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        bound = fan_in ** -0.5
        return torch.empty(*shape).uniform_(-bound, bound)

    def forward(self, x):
        # slot-wise x[b, n, :] @ W1[n]  -> (b, n, d * mult)
        x = torch.einsum("nij,bni->bnj", self.W1, x)
        x = self.dropout(self.relu(x))
        # slot-wise x[b, n, :] @ W2[n]  -> (b, n, d)
        return torch.einsum("nji,bnj->bni", self.W2, x)

In [None]:
b, n, d, mult, dropout = 2, 3, 4, 2, 0.
model = CapsuleFeedForward(n=n, d=d, dropout=dropout, mult=mult)
x = torch.rand(b, n, d)
y = model(x)
y.shape  # same as the input: (b, n, d)

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP ``d -> d * mult -> d_out`` with ReLU and dropout in between.

    ``d_out`` defaults to ``d``.
    """

    def __init__(self, d, mult=4, dropout=0., d_out=None):
        super().__init__()
        hidden = d * mult
        out_features = d if d_out is None else d_out
        layers = [
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

### Model

In [None]:
# Image batch geometry and the number of pixels hidden per image
b, h, w, c = 2, 32, 32, 3
n_masked = 128

pe = pos_tanh_embedding(h, w)[None].repeat(b, 1, 1, 1)  # (b, 4, h, w)

In [None]:
x = torch.rand(b, c, h, w)
x_emb = torch.cat([x, pe], dim=1)  # (b, c + 4, h, w)
print(x_emb.shape)

# channels last, then one token per pixel: (b, h*w, c + 4)
x_emb = x_emb.permute(0, 2, 3, 1).reshape(b, h * w, -1)
print(x_emb.shape)
_, n, d = x_emb.shape

mask = mask_generator(b, h * w, n_masked=n_masked)
print(mask.shape)

In [None]:
# Mask of the first sample as an image (False = masked position)
m = mask.reshape(b, h, w)
fig, ax = plt.subplots()
ax.imshow(m[0])
plt.show()

In [None]:
# Two "up" stages (learned queries attend to the tokens) and two "down" stages.
# The first layer reads 7-wide tokens (3 image + 4 positional channels).
lay_at_up = [
    Attention(d_q=32, d_kv=7, n_heads=1, d_head=32, dropout=0),
    Attention(d_q=32, d_kv=32, n_heads=1, d_head=32, dropout=0),
]
lay_at_down = [
    Attention(d_q=32, d_kv=32, n_heads=1, d_head=32, dropout=0),
    Attention(d_q=4, d_kv=32, n_heads=1, d_head=32, d_out=32, dropout=0),
]
lay_ff_up = [
    CapsuleFeedForward(n=64, d=32, dropout=0, mult=4),
    CapsuleFeedForward(n=64, d=32, dropout=0, mult=4),
]
lay_ff_down = [
    CapsuleFeedForward(n=64, d=32, dropout=0, mult=4),
    FeedForward(d=32, dropout=0, mult=4, d_out=c),
]
# learned query tokens: 64 queries of width 32 per stage
LQS = [torch.rand(1, 64, 32) for _ in range(2)]

In [None]:
# Positional-encoding tokens for every pixel; the hidden ones query the decoder.
zq = pe.permute(0, 2, 3, 1).reshape(b, h * w, -1)

z = x_emb
print(z.shape)
zm = masked_select(zq, ~mask)
xm = masked_select(x.permute(0, 2, 3, 1).reshape(-1, h * w, c), ~mask)

# up
for attend, feed, query in zip(lay_at_up, lay_ff_up, LQS):
    z = feed(attend(query, z))
# down
for attend, feed, query in zip(lay_at_down, lay_ff_down, [LQS[0], zm]):
    z = feed(attend(query, z))

print(z.shape)

In [None]:
class EchoModel(nn.Module):
    """Three "up" and three "down" cross-attention stages with learned latent queries.

    Args:
        n_h: number of learned latent query tokens per stage.
        d_h: latent width.
        d_in: feature size of the input tokens ``z``.
        d_e: feature size of the output query tokens ``zq``.
        c: number of output channels.
    """

    def __init__(self, n_h, d_h, d_in, d_e, c):
        super().__init__()
        # Up: the first stage reads the input tokens, later ones read latents.
        self.lay_at_up = nn.ModuleList([
            Attention(d_q=d_h, d_kv=d_kv, n_heads=1, d_head=d_h, dropout=0.1)
            for d_kv in (d_in, d_h, d_h)
        ])
        # Down: the last stage is queried with `zq` (width d_e) and projects back to d_h.
        self.lay_at_down = nn.ModuleList([
            Attention(d_q=d_h, d_kv=d_h, n_heads=1, d_head=d_h, dropout=0.1),
            Attention(d_q=d_h, d_kv=d_h, n_heads=1, d_head=d_h, dropout=0.1),
            Attention(d_q=d_e, d_kv=d_h, n_heads=1, d_head=d_h, d_out=d_h, dropout=0),
        ])
        # Shared-weight MLPs (CapsuleFeedForward is a per-slot alternative).
        self.lay_ff_up = nn.ModuleList([
            FeedForward(d=d_h, dropout=0, mult=4, d_out=d_h) for _ in range(3)
        ])
        self.lay_ff_down = nn.ModuleList([
            FeedForward(d=d_h, dropout=0, mult=4, d_out=d_out) for d_out in (d_h, d_h, c)
        ])
        self.LQS = nn.ParameterList([
            nn.Parameter(torch.rand(1, n_h, d_h), requires_grad=True) for _ in range(3)
        ])
        self.to_out = nn.Sigmoid()

    def forward(self, z, zq):
        """
            In:
                z  ... Embedding (b, n, d_in)
                zq ... PE Query  (b, m, d_e)
            Out: (b, m, c), squashed to (0, 1) by a sigmoid
        """
        # UP: each stage's learned queries attend to the previous result
        for attend, feed, latent in zip(self.lay_at_up, self.lay_ff_up, self.LQS):
            z = feed(attend(latent, z))

        # DOWN: revisit the latent queries in reverse order (LQS[1], LQS[0]),
        # then let the positional queries read out the final result
        down_queries = [self.LQS[1], self.LQS[0], zq]
        for attend, feed, query in zip(self.lay_at_down, self.lay_ff_down, down_queries):
            z = feed(attend(query, z))

        return self.to_out(z)

In [None]:
# Toy configuration for EchoModel
b, h, w, c = 2, 32, 32, 3
n_masked = 256
n_h, d_h = 32, 16  # latent queries per stage and their width
d_in, d_e = 7, 4   # input token width (c + 4 PE channels) and query width (4 PE channels)

# MASK GENERATOR
gen_mask = batch_mask_generator(b, h * w, n_masked=n_masked)

# POSITIONAL ENCODING
pe = pos_tanh_embedding(h, w).unsqueeze(0).repeat(b, 1, 1, 1)

In [None]:
# INPUT [for each step]
mask = gen_mask()
x = torch.rand(b, c, h, w)
x_emb = torch.cat([x, pe], dim=1)

# Flatten to per-pixel tokens (channels last) and split by the mask:
# hidden pixels (~mask) give targets and their queries; visible pixels give inputs.
x_tar = masked_select(x.permute(0, 2, 3, 1).reshape(b, h * w, c), ~mask)
x_que = masked_select(pe.permute(0, 2, 3, 1).reshape(b, h * w, -1), ~mask)
x_inp = masked_select(x_emb.permute(0, 2, 3, 1).reshape(b, h * w, -1), mask)

for t in (x_tar, x_que, x_inp):
    print(t.shape)

In [None]:
model = EchoModel(n_h=n_h, d_h=d_h, d_in=d_in, d_e=d_e, c=c)
print(f"#params: {count_parameters(model)}")
# model  # uncomment to display the architecture

In [None]:
x_pre = model(x_inp, x_que)
x_pre.shape

In [None]:
for t in (x_pre, x_tar):
    print(t.min(), t.max(), t.shape)

In [None]:
F.mse_loss(x_pre, x_tar)

# Train

In [None]:
# Dataset location; override with the CSPRITES_DATA environment variable.
# Known variants (server copies live under /mnt/data/csprites/):
#   black background:      single_csprites_32x32_n7_c24_a12_p6_s2_bg_1_constant_color_145152
#   structured background: single_csprites_32x32_n7_c24_a12_p6_s2_bg_inf_random_function_145152
import os

p_data = os.environ.get(
    "CSPRITES_DATA",
    '/home/matthias/projects/data/single_csprites_32x32_n7_c24_a12_p6_s2_bg_1_constant_color_145152',
)

p_ds_config = Path(p_data) / "config.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load trusted config files.
with open(p_ds_config, "rb") as file:
    ds_config = pickle.load(file)
target_variable = "shape"
target_idx = [idx for idx, target in enumerate(ds_config["classes"]) if target == target_variable][0]
n_classes = ds_config["n_classes"][target_variable]

# Per-channel normalisation from the dataset statistics. It is currently not
# part of `transform`, so images keep the range produced by ToTensor.
norm_transform = normalize_transform(ds_config["means"], ds_config["stds"])
target_transform = lambda x: x[target_idx]
transform = T.Compose([
    T.ToTensor(),
    # norm_transform,
])
inverse_norm_transform = inverse_normalize_transform(ds_config["means"], ds_config["stds"])

In [None]:
batch_size = 512
num_workers = 4

# arguments shared by both splits
ds_kwargs = dict(p_data=p_data, transform=transform, target_transform=target_transform)

# TRAIN (drop the last incomplete batch so every batch has batch_size items)
ds_train = ClassificationDataset(**ds_kwargs, split="train")
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=False, drop_last=True)

# VALID
ds_valid = ClassificationDataset(**ds_kwargs, split="valid")
dl_valid = DataLoader(ds_valid, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=False)

In [None]:
# Grid of the first training images
n_vis = 64
x, y = next(iter(dl_train))
x, y = x[:n_vis], y[:n_vis]
print(x.min(), x.max())
# x = inverse_norm_transform(x)  # only needed if norm_transform is enabled
grid_img = torchvision.utils.make_grid(x, nrow=int(np.sqrt(n_vis)))
plt.imshow(grid_img.permute(1, 2, 0))

In [None]:
n_masked = 256  # pixels hidden per image (they become the reconstruction targets)
h = w = 32
b = batch_size
n_h = 256       # learned latent queries per stage
d_h = 64        # latent width
c = 3

# POSITIONAL ENCODING: choose one of "tanh", "fourier", "linear".
# Previously all three were computed and only the last one (linear) was kept.
pe_kind = "linear"
if pe_kind == "tanh":
    pe = pos_tanh_embedding(h, w)
elif pe_kind == "fourier":
    pe = pos_fourier_embedding(h, w, max_freq=10, num_bands=4, base=2).permute(2, 0, 1)
elif pe_kind == "linear":
    pe = pos_linear_embedding(h, w)
else:
    raise ValueError(f"unknown pe_kind: {pe_kind!r}")
pe = pe.unsqueeze(0).repeat(b, 1, 1, 1)  # (b, d_e, h, w)

# The query width follows the chosen encoding (4 for tanh/linear), so the
# fourier option yields a consistent model; the inputs are pixels + encoding.
d_e = pe.shape[1]
d_in = d_e + c

# MASK GENERATOR
gen_mask = batch_mask_generator(b, h*w, n_masked=n_masked)

pe.shape

In [None]:
model = EchoModel(n_h=n_h, d_h=d_h, d_in=d_in, d_e=d_e, c=c).to(device)
print(f"#params: {count_parameters(model)}")
# model  # uncomment to display the architecture

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# multiplies the learning rate by 0.96 on every lr_scheduler.step()
lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)
loss_fn = nn.MSELoss()

In [None]:
# #################
# Overfit a single training image as a sanity check.
# Batch-size-dependent objects get their own names (`*_ovf`), so the globals
# `b`, `gen_mask` and `pe` prepared for full batches stay valid for the
# training loop below.
# #################
b_ovf = 1

# MASK GENERATOR
gen_mask_ovf = batch_mask_generator(b_ovf, h*w, n_masked=n_masked)

# POSITIONAL ENCODING: reuse the configured encoding. All rows of `pe` are
# copies of one encoding, and its width matches the model's d_e.
pe_ovf = pe[:1].repeat(b_ovf, 1, 1, 1).to(device)

x, _ = next(iter(dl_train))
x = x[:b_ovf].to(device)
num_steps = 10001
for step_idx in range(num_steps):
    model.train()

    mask = gen_mask_ovf().to(device)
    x_emb = torch.cat([x, pe_ovf], dim=1)
    # visible pixels (mask) -> input tokens; hidden pixels (~mask) -> targets,
    # queried by their positional encoding
    x_tar = masked_select(x.permute(0,2,3,1).reshape(b_ovf,h*w,c), ~mask)
    x_que = masked_select(pe_ovf.permute(0,2,3,1).reshape(b_ovf,h*w,-1), ~mask)
    x_inp = masked_select(x_emb.permute(0,2,3,1).reshape(b_ovf,h*w,-1), mask)

    optimizer.zero_grad()
    x_pre = model(x_inp, x_que)
    loss = loss_fn(x_pre, x_tar)
    loss.backward()
    optimizer.step()

    if step_idx % 100 == 0:
        print(loss.item())

In [None]:
# Move the overfit results to the CPU for plotting. `mask` is moved as well,
# because the plotting cell uses it to index CPU tensors.
x = x.cpu()
x_pre = x_pre.detach().cpu()
x_tar = x_tar.cpu()
x_que = x_que.cpu()
x_emb = x_emb.cpu()
mask = mask.cpu()

In [None]:
n_vis = min(x.shape[0], 10)
for idx in range(n_vis):
    x_ori = x[idx].permute(1, 2, 0)  # (h, w, c)
    # `mask` may still live on the training device; index CPU tensors with a CPU mask.
    m = mask[idx].cpu()              # True = visible input pixel, False = hidden target

    # Reconstruction: original pixels with the hidden ones replaced by the prediction
    x_rec = torch.clone(x_ori).reshape(h*w, -1)
    x_rec[~m] = x_pre[idx]
    x_rec = x_rec.reshape(h, w, c)

    # Binary image of the hidden positions
    x_mask = torch.zeros(h*w)
    x_mask[~m] = 1
    x_mask = x_mask.reshape(h, w)

    # DIFF: absolute error on the hidden pixels only
    diff = torch.abs(x_ori.reshape(h*w, -1)[~m] - x_pre[idx])
    x_diff = torch.zeros(h * w, c)
    x_diff[~m] = diff
    x_diff = x_diff.reshape(h, w, c)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = (x_ori, x_mask, x_rec, x_diff)
    titles = ("original", "hidden pixels", "reconstruction", "|error|")
    for ax, img, title in zip(axes, panels, titles):
        ax.imshow(img)
        ax.set_title(title)
    plt.show()

In [None]:
num_epochs = 11
# Re-derive the batch-dependent globals for full-size batches. An earlier cell
# (e.g. a single-image overfit run) may have rebound `b`, `gen_mask` and `pe`
# for another batch size. All rows of `pe` are copies of one encoding, so its
# first row is enough.
b = batch_size
gen_mask = batch_mask_generator(b, h*w, n_masked=n_masked)
pe = pe[:1].repeat(b, 1, 1, 1).to(device)

for epoch_idx in range(num_epochs):
    # ####################
    # TRAIN
    # ####################
    model.train()
    desc = "Train [{:3}/{:3}]:".format(epoch_idx, num_epochs)
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')

    epoch_loss = 0
    steps = 0
    for x, _ in pbar:
        x = x.to(device)
        optimizer.zero_grad()

        # INPUT AND TARGET: visible pixels (mask) are the input tokens, hidden
        # pixels (~mask) are the targets, queried by their positional encoding
        mask = gen_mask().to(device)
        x_emb = torch.cat([x, pe], dim=1)
        x_tar = masked_select(x.permute(0,2,3,1).reshape(b,h*w,c), ~mask)
        x_que = masked_select(pe.permute(0,2,3,1).reshape(b,h*w,-1), ~mask)
        x_inp = masked_select(x_emb.permute(0,2,3,1).reshape(b,h*w,-1), mask)

        x_pre = model(x_inp, x_que)
        loss = loss_fn(x_pre, x_tar)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        epoch_loss += step_loss
        steps += 1
        pbar.set_postfix({'loss': step_loss, 'epoch': epoch_loss / steps})
    lr_scheduler.step()

In [None]:
x, _ = next(iter(dl_valid))
x = x.to(device)

# Evaluation mode: switches off the dropout inside the attention layers.
model.eval()

# INPUT AND TARGET
mask = gen_mask().to(device)
pe = pe.to(device)
x_emb = torch.cat([x, pe], dim=1)

x_tar = masked_select(x.permute(0,2,3,1).reshape(b,h*w,c), ~mask).cpu()
x_que = masked_select(pe.permute(0,2,3,1).reshape(b,h*w,-1), ~mask).cpu()
x_inp = masked_select(x_emb.permute(0,2,3,1).reshape(b,h*w,-1), mask).cpu()

with torch.no_grad():
    x_pre = model(x_inp.to(device), x_que.to(device)).cpu()

# The plots below index these CPU tensors with `mask`, so it moves to the CPU too.
x = x.cpu()
mask = mask.cpu()

In [None]:
for t in (x_tar, x_que, x_inp, x_pre):
    print(t.min(), t.max())

In [None]:
n_vis = min(x.shape[0], 8)  # never index past the batch
for idx in range(n_vis):
    x_ori = x[idx].permute(1, 2, 0)  # (h, w, c)
    # `mask` may still live on the training device; index CPU tensors with a CPU mask.
    m = mask[idx].cpu()              # True = visible input pixel, False = hidden target

    # Reconstruction: original pixels with the hidden ones replaced by the prediction
    x_rec = torch.clone(x_ori).reshape(h*w, -1)
    x_rec[~m] = x_pre[idx]
    x_rec = x_rec.reshape(h, w, c)

    # Binary image of the hidden positions
    x_mask = torch.zeros(h*w)
    x_mask[~m] = 1
    x_mask = x_mask.reshape(h, w)

    # DIFF: absolute error on the hidden pixels only
    diff = torch.abs(x_ori.reshape(h*w, -1)[~m] - x_pre[idx])
    x_diff = torch.zeros(h * w, c)
    x_diff[~m] = diff
    x_diff = x_diff.reshape(h, w, c)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = (x_ori, x_mask, x_rec, x_diff)
    titles = ("original", "hidden pixels", "reconstruction", "|error|")
    for ax, img, title in zip(axes, panels, titles):
        ax.imshow(img)
        ax.set_title(title)
    plt.show()

In [None]:
(x_ori.reshape(h * w, -1)[~m] - x_pre[idx]).abs()

In [None]:
x_diff.reshape(h * w, -1)[~m]  # x_diff is (h, w, c) but m is (h*w,); flatten first

In [None]:
# m is a flat (h*w,) mask, so index a per-pixel view of x_rec and restore its shape.
x_rec = x_rec.reshape(h * w, -1)
x_rec[~m] = x_pre[idx]
x_rec = x_rec.reshape(h, w, -1)

In [None]:
x_pre.size()

In [None]:
x_rec = x_ori.clone()

In [None]:
x_rec.size()

In [None]:
x_pre.size()