In [1]:
import torch
from torch import nn

import os
import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader

from multiprocessing import Pool

In [4]:
class SatelliteDataset(Dataset):
    def __init__(self, path):
        self.images = np.array(os.listdir(path))
        self.img_dir = path

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        image = self.images[idx]
        #img = np.load(self.img_dir + image)

        print(image)

# Results depend on validation set, so we fix the seed
np.random.seed(0)
torch.manual_seed(0)

train_ds = SatelliteDataset("datasets/landsat8_train/train/")
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)

In [5]:
next(iter(train_dl))

Landsat8_SR_B4_-55.16_-3.73_2016_04_26.tiff
Landsat8_SR_B6_-55.06_-3.71_2014_11_15.tiff
Landsat8_SR_B4_-55.14_-3.95_2021_03_23.tiff
Landsat8_SR_B7_-55.16_-3.39_2018_11_26.tiff
Landsat8_ST_B10_-54.76_-3.43_2014_06_24.tiff
Landsat8_SR_B1_-54.92_-3.59_2015_01_18.tiff
Landsat8_SR_B4_-54.94_-4.11_2016_12_22.tiff
Landsat8_SR_B4_-54.66_-4.31_2020_04_21.tiff
Landsat8_SR_B6_-54.92_-3.63_2013_12_14.tiff
Landsat8_ST_B10_-54.70_-3.69_2013_09_25.tiff
Landsat8_ST_B10_-54.58_-4.35_2020_04_21.tiff
Landsat8_SR_B4_-54.94_-4.27_2021_06_27.tiff
Landsat8_SR_B5_-54.78_-3.61_2021_03_07.tiff
Landsat8_SR_B2_-55.06_-3.93_2015_10_01.tiff
Landsat8_SR_B6_-54.96_-4.39_2018_08_06.tiff
Landsat8_SR_B6_-54.66_-4.01_2019_03_18.tiff


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>

In [None]:
#https://www.researchgate.net/figure/The-Vision-Transformer-architecture-a-the-main-architecture-of-the-model-b-the_fig2_348947034
class Transformer(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.mha = nn.MultiheadAttention(emb_dim, num_heads=4, batch_first=True) ##2?

        self.ln2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            #nn.Dropout(0.1),
        )

    def forward(self, x):
        x_ln = self.ln1(x)
        att, _ = self.mha(x_ln, x_ln, x_ln)
        
        att = self.ln2(att + x)
        return self.mlp(att) + att

In [None]:
#https://medium.com/@14prakash/masked-autoencoders-9e0f7a4a2585
class MyMAE(nn.Module):
    def __init__(self, in_c, img_size, patch_size, emb_dim):
        super().__init__()
        assert img_size % patch_size == 0

        self.emb_dim = emb_dim
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (self.img_size // self.patch_size) ** 2

        self.palette_emb = nn.Linear(3 * 16, self.emb_dim)

        self.patch_embedding = nn.Conv2d(in_c, self.emb_dim, 
                                         kernel_size=patch_size, 
                                         stride=patch_size,
                                         bias=True)

        self.pos_embedding = nn.Parameter(torch.zeros(self.num_patches, self.emb_dim), requires_grad=False)
        
        self.encoder = nn.Sequential(*[Transformer(self.emb_dim) for _ in range(10)])
        self.decoder = nn.Sequential(*[Transformer(self.emb_dim) for _ in range(2)])

        self.decoder_emb_dim = self.emb_dim
        self.decoder_emb = nn.Linear(self.emb_dim, self.decoder_emb_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_emb_dim))
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.decoder_pos_emb = nn.Parameter(torch.zeros(self.num_patches, self.decoder_emb_dim))

        self.img_recov = nn.Linear(self.decoder_emb_dim, in_c * (self.patch_size ** 2), bias=True)
        

    def forward(self, x, xp):
        bs, _, _, _ = x.shape
        device = x.device

        patches = self.patch_embedding(x)
        patches = patches.flatten(2, 3).transpose(1, 2)

        masked_entries = int(0.75 * self.num_patches)
        mask = torch.randperm(self.num_patches, device=device) #img size independent?
        mask = mask[:-masked_entries]

        pal_emb = self.palette_emb(xp.flatten(1, 2))
        pos_emb = self.pos_embedding[mask, :]

        tokens = patches[:, mask, :] + pos_emb[None, ...] + pal_emb[:, None, :]
        features = self.encoder(tokens)

        ###### bottleneck

        tokens = self.mask_token.repeat(bs, self.num_patches, 1)
        tokens[:, mask, :] = self.decoder_emb(features)

        tokens = tokens + self.decoder_pos_emb
        features = self.decoder(tokens)

        image = self.img_recov(features)

        image = image.transpose(1,2)
        image = nn.functional.fold(image, 
                                   kernel_size=self.patch_size, 
                                   output_size=self.img_size, 
                                   stride=self.patch_size)
        
        return image

In [None]:
model = MyMAE(10, 85, 5, 256)