### Multiscale Residual Spatiotemporal Vision Transformer (MR-ST-ViT | PixFormer)

### Misael M. Morales, 2024
***

In [None]:
from main import *

# Instantiate the experiment driver used by every cell below.
# NOTE(review): `Heterogeneity` is not defined in this notebook; presumably it
# comes from the wildcard `from main import *` above — TODO confirm.
hete = Heterogeneity()

In [None]:
# Build the data loaders. A later cell reads `hete.train_dataloader`, which this
# call presumably populates — run it before that cell (TODO confirm in main.py).
hete.make_dataloaders()

In [None]:
# Run training; the training loop itself is implemented in `main` (presumably
# Heterogeneity.trainer) — see that module for epochs/optimizer settings.
hete.trainer()

In [None]:
# Evaluate the trained model; implementation lives in `main` (presumably
# Heterogeneity.tester) — TODO confirm what metrics/outputs it produces.
hete.tester()

In [None]:
# Plot the loss curves recorded by the driver (presumably during `trainer()`;
# implementation lives in `main`).
hete.plot_losses()

***
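# END — main workflow ends here; the cells below are scratch checks (batch shapes, decoder shape walk-through, sample plots)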

In [None]:
# Grab a single training batch to inspect tensor shapes. `next(iter(...))`
# fetches the first batch directly (no unused loop index / `break`), and an
# empty loader raises StopIteration here instead of leaving `x`/`y` undefined
# for the later cells.
x, y = next(iter(hete.train_dataloader))
print(x.shape, y.shape)

In [None]:
# View the flat batch as (realization, timestep, channel, H, W).
# NOTE(review): assumes the first batch holds exactly 32 realizations x 40
# timesteps of 64x64 maps, with 3 input and 2 target channels — TODO confirm
# against the dataset definition in main.py.
N_REAL, N_TIME = 32, 40

x_sample = x.reshape(N_REAL, N_TIME, 3, 64, 64)
y_sample = y.reshape(N_REAL, N_TIME, 2, 64, 64)


def _plot_sample_grid(sample, channel, nrows=5, ncols=10,
                      real_step=5, time_step=4, cmap='jet'):
    """Plot one channel of a (realization, time, channel, H, W) array as a grid.

    Row ``i`` shows realization ``i * real_step``; column ``j`` shows timestep
    ``j * time_step``. Row labels and column titles are set once per row/column.

    Args:
        sample: array-like indexed as ``sample[realization, time, channel]``.
        channel: channel index to display.
        nrows, ncols: grid dimensions.
        real_step, time_step: stride between displayed realizations/timesteps.
        cmap: matplotlib colormap name.
    """
    # squeeze=False keeps `axs` 2-D even for a 1-row or 1-column grid.
    fig, axs = plt.subplots(nrows, ncols, figsize=(20, 6), squeeze=False)
    for j in range(ncols):
        axs[0, j].set(title='t = {}'.format(j * time_step))
    for i in range(nrows):
        k = i * real_step
        axs[i, 0].set(ylabel='# {}'.format(k))
        for j in range(ncols):
            ax = axs[i, j]
            ax.imshow(sample[k, j * time_step, channel], cmap=cmap)
            ax.set(xticks=[], yticks=[])
    plt.tight_layout()
    plt.show()


_plot_sample_grid(x_sample, channel=0)  # channel 0 of the inputs x
_plot_sample_grid(y_sample, channel=1)  # channel 1 of the targets y

In [None]:
import torch  # for no_grad; torch is already a dependency (nn is torch.nn)

print(x.shape, '  | Original')

# NOTE(review): these must match the ViTencoder configuration in main.py
# (embedding width and latent grid side) — TODO confirm.
projection_dim = 64*4
latent_size    = 8

# Pure shape check: no gradients are needed, so skip building an autograd
# graph over the whole batch (saves a lot of memory).
with torch.no_grad():
    y = ViTencoder()(x)
    print(y.shape, '      | Encoded')

    # reshape (not view) so this also works if the encoder output is
    # non-contiguous; identical result whenever view would have succeeded.
    # NOTE(review): if the encoder returns (N, tokens, dim), reinterpreting it
    # as (N, dim, H, W) mixes the token and channel axes; a permute(0, 2, 1)
    # may be needed first — TODO confirm the encoder's output layout.
    y = y.reshape(-1, projection_dim, latent_size, latent_size)
    print(y.shape, '  | Reshaped')

def conv_block(in_channels, out_channels):
    """Build one decoder stage as an ``nn.Sequential``.

    Layer order: 3x3 conv (in_channels -> out_channels), squeeze-excitation,
    instance norm, PReLU, multiscale residual, 2x upsample, 3x3 conv back to
    ``out_channels``.

    NOTE(review): the last conv expects ``2 * out_channels`` input channels,
    i.e. it assumes ``MultiScaleResidual`` doubles the channel count — TODO
    confirm in main.py.
    """
    squeeze_channels = out_channels // 4
    stage_layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        SqueezeExcitation(out_channels, squeeze_channels),
        nn.InstanceNorm2d(out_channels),
        nn.PReLU(),
        MultiScaleResidual(),
        nn.Upsample(scale_factor=2),
        nn.Conv2d(2 * out_channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*stage_layers)

import torch  # for no_grad; torch is already a dependency (nn is torch.nn)

# Decoder shape walk-through: each conv_block halves the channel count
# (projection_dim -> /2 -> /4 -> /8) and upsamples 2x, then a final conv maps
# to 2 output channels. Pure shape check, so skip building an autograd graph.
with torch.no_grad():
    y = conv_block(projection_dim, projection_dim//2)(y)
    print(y.shape, '| ConvBlock 1')

    y = conv_block(projection_dim//2, projection_dim//4)(y)
    print(y.shape, ' | ConvBlock 2')

    y = conv_block(projection_dim//4, projection_dim//8)(y)
    print(y.shape, ' | ConvBlock 3')

    y = nn.Conv2d(projection_dim//8, 2, kernel_size=3, padding=1)(y)
    print(y.shape, '  | Out')

In [None]:
# View the decoder output as (realization, timestep, channel, H, W) and move it
# to a NumPy array for plotting. `.cpu()` is a no-op for CPU tensors and avoids
# a failure if `y` lives on a GPU.
# NOTE(review): assumes 32 realizations x 40 timesteps per batch — TODO confirm.
y_sample = y.reshape(32, 40, 2, 64, 64).detach().cpu().numpy()

fig, axs = plt.subplots(5, 10, figsize=(20,6))
# Column titles (timesteps) only need to be set once per column.
for j in range(10):
    axs[0, j].set(title='t = {}'.format(j*4))
for i in range(5):
    k = i*5
    # Row label (realization index) only needs to be set once per row.
    axs[i, 0].set(ylabel='# {}'.format(k))
    for j in range(10):
        t = j*4
        axs[i, j].imshow(y_sample[k, t, 1], cmap='jet')
        axs[i, j].set(xticks=[], yticks=[])
plt.tight_layout(); plt.show()