After digging into the MAE's [codes](https://github.com/facebookresearch/mae/blob/main/models_mae.py), I realize the patchify codes are tailor-made for square images. Modifying the codes to fit our "1D images" (spectra) is not trivial, so I had to step back and use the BERT again as the backbone (the progress of `pilot_02.ipynb`) and patchify the spectra into 16*1 images in a naive way. 

BTW, the [notes](https://hackmd.io/lTqNcOmQQLiwzkAwVySh8Q) for MAE provide nice general understanding of the model, which is a good supplenment to the [original paper](https://arxiv.org/abs/2111.06377).

# Modify BERT-like model
I first want to check if the data format after patchifying is able to get through the BERT-like model. Also, the input data format is [batch_size, seq_len, embedding_dim] after torch.dataloader, which is different from format in the tutorial [seq_len, batch_size, embedding_dim]. The modifications are as follows:

1. Modify the operation to fit the new input data format.
1. Simply reshape the input data to mimic the patchification. The data shape become [batch size, spectrum length devided by patch size, patch size]. e.g., [4, 128, 16] for batch size 4, spectrum length 2048, and patch size 16.

Please note that the positional encoding is patch-wise, not pixel-wise. There is no positional encoding within a patch yet.

In [1]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        #pe = torch.zeros(max_len, 1, d_model)
        pe = torch.zeros(1, max_len, d_model)
        #pe[:, 0, 0::2] = torch.sin(position * div_term)
        #pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            (old) x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
            (modified) x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
            embedding_dim is 16 in our case (path size: 16 channel values)
        """
        #x = x + self.pe[:x.size(0)]
        x = x + self.pe[:x.size(1)]
        return self.dropout(x)

class TransformerModel(nn.Module):

    #def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
    def __init__(self, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        # no need to embed token because we are not dealing with words
        #self.embedding = nn.Embedding(ntoken, d_model)
        #self.d_model = d_model
        # no need to linearly transform the output because we aim for reconstructing the masked spectrum
        #self.linear = nn.Linear(d_model, ntoken)

        #self.init_weights()

   # these weights are not used in the current version
    #def init_weights(self) -> None:
    #    initrange = 0.1
    #    self.embedding.weight.data.uniform_(-initrange, initrange)
    #    self.linear.bias.data.zero_()
    #    self.linear.weight.data.uniform_(-initrange, initrange)

    #def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
    def forward(self, src: Tensor) -> Tensor:
        """
        Arguments:
            src: Tensor, shape ``[seq_len, batch_size]``

        Returns:
            output Tensor of shape ``[seq_len, batch_size, ntoken]``
            the ntoken is 2048/16=128 in our case (spectrum length/patch size)
        """
        #src = self.embedding(src) * math.sqrt(self.d_model)
        src = src * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        #if src_mask is None:
        #    """Generate a square causal mask for the sequence. The masked positions are filled with float('-inf').
        #    Unmasked positions are filled with float(0.0).
        #    """
        #    src_mask = nn.Transformer.generate_square_subsequent_mask(len(src)).to(device)
        #output = self.transformer_encoder(src, src_mask)
        #output = self.transformer_encoder(src)
        #output = self.linear(output)
        return src

## Reshape the input data

In [35]:
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    # We don't need the labels and transform for now
    def __init__(self, annotations_file, input_dir, target_dir, mask_dir):
        """
        input_dir: directory with masked spe files
        target_dir: directory with original spe files
        mask_dir: directory with boolean mask files
        """
        self.spe_info = pd.read_csv(annotations_file)
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.mask_dir = mask_dir
        
    def __len__(self):
        return len(self.spe_info)

    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.spe_info.iloc[idx, 0])
        target_path = os.path.join(self.target_dir, self.spe_info.iloc[idx, 0])
        mask_path = os.path.join(self.mask_dir, self.spe_info.iloc[idx, 0])

        # reshape to (128, 16)
        input_spe = np.loadtxt(input_path, delimiter=',', dtype=int).reshape(-1, 16)
        target_spe = np.loadtxt(target_path, delimiter=',', dtype=int).reshape(-1, 16)
        mask = np.loadtxt(mask_path, delimiter=',', dtype=int).reshape(-1, 16)

        output = {'input_spe': input_spe,
                  'target_spe': target_spe,
                  'mask': mask}
  
        return output

In [36]:
import numpy as np
np.array([1,2,3,4,5,6,7,8,9,10]).reshape(-1, 2)

array([[ 1,  2],
       [ 3,  4],
       [ 5,  6],
       [ 7,  8],
       [ 9, 10]])

In [37]:
from torch import Generator
from torch.utils.data import random_split
from torch.utils.data import DataLoader

dataset = CustomImageDataset('data/info_20231121.csv', 'data/masked', 'data/spe', 'data/mask')
data_train, data_test = random_split(dataset, [0.8, 0.2], generator=Generator().manual_seed(24))

train_dataloader = DataLoader(data_train, batch_size=4, shuffle=True)
batch = next(iter(train_dataloader))
print(batch['input_spe'].size())
print(batch)


torch.Size([4, 128, 16])
{'input_spe': tensor([[[99999999, 99999999, 99999999,  ..., 99999999, 99999999, 99999999],
         [       0,        0,        0,  ...,        0,        0, 99999999],
         [       0,        0, 99999999,  ...,        6,        8,       16],
         ...,
         [99999999,        0, 99999999,  ...,        0,        0, 99999999],
         [99999999, 99999999,        0,  ...,        0,        0, 99999999],
         [99999999,        0, 99999999,  ..., 99999999, 99999999,        0]],

        [[       0,        0,        0,  ..., 99999999,        0, 99999999],
         [       0, 99999999,        0,  ..., 99999999, 99999999,        0],
         [99999999,        0,        0,  ..., 99999999, 99999999,       12],
         ...,
         [       0,        0,        0,  ...,        0,        0,        0],
         [       0, 99999999,        0,  ...,        0, 99999999,        0],
         [99999999,        0, 99999999,  ...,        0, 99999999,        0]],

     

0 in mask is not masked. 1 in mask is masked.

## Initiate an instance
The model hyperparameters are defined below, which is identical to BERT-base (expect the head is increased to 16 to divide the emsize).



In [48]:
spe_len = 2048
patch_size = 16 
emsize = int(spe_len/patch_size)  # embedding dimension = 128 ptaches
d_hid = 768  # dimension of the feedforward network model in ``nn.TransformerEncoder``
nlayers = 12  # number of ``nn.TransformerEncoderLayer`` in ``nn.TransformerEncoder``
nhead = 16  # number of heads in ``nn.MultiheadAttention``
dropout = 0.1  # dropout probability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerModel(emsize, nhead, d_hid, nlayers, dropout).to(device)

It shows no error! Okay, I'll modifiy the data properly and train the model.

# Modify MAE with our naive patchification
Actually, modifying the MAE is not that difficult. Below codes are from `models_mae.py` in [MAE repo](https://github.com/facebookresearch/mae/blob/main/models_mae.py). Our input dataset is also modified accordingly. 

In [2]:
from functools import partial

import torch
import torch.nn as nn

# skip the delicate patch embedding
#from timm.models.vision_transformer import PatchEmbed, Block
from timm.models.vision_transformer import Block

# it is also modified to match our naive patch embedding
from util.pos_embed import get_1d_sincos_pos_embed


class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    # modify img_size to spe_size (2048)
    # no need to have in_chans and embed_dim because we areusing naive patch embedding
    def __init__(self, spe_size=2048, patch_size=16,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
                #self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = int(spe_size/patch_size)  # 128
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        #pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], self.num_patches, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        #decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True)
        decoder_pos_embed = get_1d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.num_patches, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        # no need to initialize patch_embed because we are using naive patch embedding
        #w = self.patch_embed.proj.weight.data
        #torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, spes):
        """
        (old)
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)

        (modified)
        spe: (N, spe_size)
        x: (N, num_patches, patch_size)
        """
        #p = self.patch_embed.patch_size[0]
        #assert spes.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        assert spes.shape[1] % self.patch_size == 0

        #h = w = imgs.shape[2] // p
        #x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = spes.reshape(shape=(spes.shape[0], self.num_patches, self.patch_size))
        # our patch embedding is naive, no need to permute and reshape again
        #x = torch.einsum('nchpwq->nhwpqc', x) # relevant to x.permute(0, 3, 4, 1, 2, 5)
        #x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x):
        """
        (old)
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)

        (modified)
        x: (N, num_patches, patch_size)
        spe: (N, spe_size)
        """
        #p = self.patch_size
        #h = w = int(x.shape[1]**.5)
        #assert h * w == x.shape[1]
        
        #x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        #x = torch.einsum('nhwpqc->nchpwq', x)
        #imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        x = x.reshape(shape=(x.shape[0], self.num_patches * self.patch_size))
        return x

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))
        
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        #x = self.patch_embed(x)
        x = self.patchify(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, spes, pred, mask):
        """
        (old)
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove, 

        (modified)
        spes: [N, spe_size]
        pred: [N, num_patches, patch_size]
        mask: [N, num_patches], 0 is keep, 1 is remove
        """
        target = self.patchify(spes)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, spes, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(spes, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, path_size]
        loss = self.forward_loss(spes, pred, mask)
        return loss, pred, mask
    
def mae_vit_base_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_large_patch16_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def mae_vit_huge_patch14_dec512d8b(**kwargs):
    model = MaskedAutoencoderViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks

In [3]:
noise = torch.rand(2, 5) 
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

print(noise)
print(ids_shuffle)
print(ids_restore)

tensor([[0.4163, 0.7833, 0.7865, 0.6378, 0.0169],
        [0.7195, 0.4671, 0.3825, 0.1838, 0.6893]])
tensor([[4, 0, 3, 1, 2],
        [3, 2, 1, 4, 0]])
tensor([[1, 3, 4, 2, 0],
        [4, 2, 1, 0, 3]])


In [4]:
len_keep = 3
ids_shuffle[:, :len_keep]

tensor([[4, 0, 3],
        [3, 2, 1]])

In [5]:
model = mae_vit_base_patch16()
print(model)

MaskedAutoencoderViT(
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (decoder_embed): Linear(in_features=768, out_features=512, bias=True)
  (decoder_blocks): ModuleList(
    (0-7): 8 x Block(
      (norm1): LayerNorm((512,), eps=1e

The version of timm is fixed in 0.4.5 as recommended by the MAE repo. Actually, MAE repo requires 0.3.2 in `main_pretrain.py`, but it has issue with the torch version. So I use 0.4.5 that mentioned in `mae_visualize.ipynb`. There will be some works regarding the version issue later since MAE doesn't provide detailed version information.

In [6]:
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, input_dir):
        """
        input_dir: directory with spe files
        """
        self.spe_info = pd.read_csv(annotations_file)
        self.input_dir = input_dir
        
    def __len__(self):
        return len(self.spe_info)

    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.spe_info.iloc[idx, 0])
        spe = np.loadtxt(input_path, delimiter=',', dtype=int)  
        return spe

In [None]:
spe = np.loadtxt('data/spe/1.csv', delimiter=',', dtype=int)
spe.shape

(2048,)

In [8]:
from torch import Generator
from torch.utils.data import random_split
from torch.utils.data import DataLoader

dataset = CustomImageDataset('data/info_20231121.csv', 'data/spe')
data_train, data_test = random_split(dataset, [0.8, 0.2], generator=Generator().manual_seed(24))

train_dataloader = DataLoader(data_train, batch_size=4, shuffle=True)
batch = next(iter(train_dataloader))
print(batch.size())
print(batch)


torch.Size([4, 2048])
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])


In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = mae_vit_base_patch16().to(device)

In [None]:
loss, _, _ = model(batch.to(device), mask_ratio=0.4)

It shows some errors. To debug more efficiently, I change to code in scripts. The folder structure is similiar to the MAE repo. 