## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
from pathlib import Path

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
sys.path.append('/home/k64835/Master-Thesis-SITS')

scripts_path = Path("../Data-Preprocessing/").resolve()
sys.path.append(str(scripts_path))

scripts_path = Path("../Evaluation/").resolve()
sys.path.append(str(scripts_path))

In [10]:
import pickle
import math
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestCentroid
from scripts.data_visualiser import *
from sklearn.manifold import TSNE 
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from model_scripts.subpatch_extraction import *
from scripts.data_loader import *
from scripts.data_preprocessor import *
from scripts.temporal_data_preprocessor import *
from scripts.temporal_data_loader import *
from scripts.temporal_visualiser import *
from scripts.temporal_chanel_refinement import *
from model_scripts.model_helper import *
from model_scripts.dataset_creation import *
from model_scripts.train_model_ae import *
from model_scripts.model_visualiser import *
from model_scripts.clustering import *
from model_scripts.train_model_dcec import *
from evaluation_scripts.evaluation_helper import *
from evaluation_scripts.result_visualiser import *
from Pipeline.temporal_preprocessing_pipeline import *
from evaluation_scripts.result_visualiser import *
from Pipeline.temporal_preprocessing_pipeline import *
import numpy as np
import config as config
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import skimage.measure
import torch
import torch.nn as nn
import torch.optim as optim

## Dataset prep: B10

### Loading the pre-processed data

Data: Extracted and Pre-processed Patches (each patch containing a sugarbeet field)

Dimensions: (N, T, C, H, W) = (N, 7, 10, 64, 64)

In [4]:
preprocessing_pipeline = PreProcessingPipelineTemporal()
field_numbers_train, acquisition_dates_train, patch_tensor_train, images_visualisation_train = preprocessing_pipeline.get_processed_temporal_cubes('train', 'allbands')
field_numbers_eval, acquisition_dates_eval, patch_tensor_eval, images_visualisation_eval = preprocessing_pipeline.get_processed_temporal_cubes('eval', 'allbands')
patch_tensor_train.shape, patch_tensor_eval.shape

(torch.Size([2425, 7, 10, 64, 64]), torch.Size([48, 7, 10, 64, 64]))

## MAE

In [18]:
import torch

def patchify(images, patch_size=4):

    B, T, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0, "Image dimensions must be divisible by patch_size"
    
    patches = images.unfold(3, patch_size, patch_size).unfold(4, patch_size, patch_size)
    patches = patches.contiguous().view(B, T, C, -1, patch_size, patch_size)  # (B, T, C, num_patches, patch_size, patch_size)

    num_patches = patches.shape[3]

    # Create a valid patch mask (1 for valid patches, 0 for invalid patches)
    valid_patch_mask = (patches.sum(dim=[4, 5]) != 0).float()  # Shape: (B, T, num_patches)
    
    return patches, valid_patch_mask


In [19]:
images = patch_tensor_train  # Example input tensor with shape (B=4, T=7, C=10, H=64, W=64)
patch_size = 4

patches, valid_patch_mask = patchify(images, patch_size)

print("Patches shape:", patches.shape)  # Should print: (B, T, C, num_patches, patch_size, patch_size)
print("Valid patch mask shape:", valid_patch_mask.shape)  # Should print: (B, T, C, num_patches)

Patches shape: torch.Size([2425, 7, 10, 256, 4, 4])
Valid patch mask shape: torch.Size([2425, 7, 10, 256])


### POS encoding

In [11]:
# Positional encoding function (sinusoidal)
class PositionalEncoding(nn.Module):
    def __init__(self, num_patches, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        
        # Spatial Positional Encoding (2D)
        self.spatial_pos_encoding = self.get_2d_positional_encoding(num_patches, embed_dim)
        
        # Temporal Positional Encoding (1D)
        self.temporal_pos_encoding = self.get_1d_positional_encoding(max_len, embed_dim)
    
    def get_2d_positional_encoding(self, num_patches, embed_dim):
        # Get the 2D positional encoding (sine/cosine)
        grid_size = int(math.sqrt(num_patches))  # Assuming square grid
        
        # Generate meshgrid for spatial positions
        y_pos, x_pos = torch.meshgrid(torch.arange(grid_size), torch.arange(grid_size))
        y_pos = y_pos.flatten()
        x_pos = x_pos.flatten()
        
        # Create positional encoding matrix for each spatial position
        pos = torch.stack([x_pos, y_pos], dim=-1).unsqueeze(0)  # shape: (1, num_patches, 2)
        
        # Apply sine/cosine function
        angle_rates = 1 / torch.pow(10000, (2 * torch.arange(0, embed_dim, 2).float()) / embed_dim)
        angle_rads = pos @ angle_rates  # (1, num_patches, embed_dim // 2)
        
        # Apply sin and cos functions
        pos_encoding = torch.cat([torch.sin(angle_rads), torch.cos(angle_rads)], dim=-1)
        return pos_encoding
    
    def get_1d_positional_encoding(self, max_len, embed_dim):
        # Generate 1D positional encodings (sine/cosine)
        position = torch.arange(max_len).unsqueeze(1)  # shape: (max_len, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angle_rads = position * div_term  # (max_len, embed_dim // 2)
        
        # Apply sin and cos functions
        pos_encoding = torch.cat([torch.sin(angle_rads), torch.cos(angle_rads)], dim=-1)  # shape: (max_len, embed_dim)
        return pos_encoding.unsqueeze(0)  # shape: (1, max_len, embed_dim)
    
    def forward(self, spatial_indices, temporal_indices):
        # spatial_indices: (batch_size, num_patches)
        # temporal_indices: (batch_size, T)
        
        spatial_encoding = self.spatial_pos_encoding[:, spatial_indices]  # shape: (batch_size, num_patches, embed_dim)
        temporal_encoding = self.temporal_pos_encoding[:, temporal_indices]  # shape: (batch_size, T, embed_dim)
        
        return spatial_encoding, temporal_encoding

### ViT Block

In [12]:
# ViT Block (Transformer Block)
class ViTBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x shape: (num_patches, batch_size, embed_dim)
        attn_out, _ = self.attn(x, x, x)
        x = self.ln1(attn_out + x)  # Add & Norm
        ffn_out = self.ffn(x)
        x = self.ln2(ffn_out + x)  # Add & Norm
        return x

### Encoder

In [14]:
# ViT Encoder
class ViTEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, num_layers, num_patches, patch_size=4, img_size=64, num_channels=10, max_len=5000):
        super(ViTEncoder, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        # Patch embedding layer
        self.patch_embed = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(embed_dim=embed_dim, num_patches=num_patches, max_len=max_len)
        
        # Transformer encoder layers
        self.transformer_blocks = nn.ModuleList(
            [ViTBlock(embed_dim, num_heads, ff_dim) for _ in range(num_layers)]
        )

    def forward(self, x, temporal_indices):
        # x shape: (B, T, C, H, W) --> (B, T, num_patches, embed_dim)
        B, T, C, H, W = x.shape
        patches = self.patch_embed(x.view(B * T, C, H, W))  # (B * T, embed_dim, H/patch_size, W/patch_size)
        patches = patches.flatten(2).transpose(1, 2)  # (B * T, num_patches, embed_dim)
        
        # Spatial and Temporal Positional Encoding
        spatial_indices = torch.arange(patches.size(1)).unsqueeze(0).expand(B * T, -1).to(x.device)
        spatial_encoding, temporal_encoding = self.pos_encoder(spatial_indices, temporal_indices)
        
        # Add positional encodings
        patches = patches + spatial_encoding + temporal_encoding.unsqueeze(1)  # (B * T, num_patches, embed_dim)
        
        # Pass through transformer blocks
        for block in self.transformer_blocks:
            patches = block(patches)
        
        return patches

### Decoder

In [15]:
# ViT Decoder
class ViTDecoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, num_layers, num_patches, patch_size=16, img_size=64, num_channels=10):
        super(ViTDecoder, self).__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_patches = num_patches
        
        # Transformer decoder layers
        self.transformer_blocks = nn.ModuleList(
            [ViTBlock(embed_dim, num_heads, ff_dim) for _ in range(num_layers)]
        )
        
        # Decoder output to image reconstruction
        self.fc = nn.Linear(embed_dim, num_channels * patch_size * patch_size)

        # Reshape into image space
        self.reshaper = nn.Unflatten(2, (num_channels, patch_size, patch_size))
    
    def forward(self, x):
        # x shape: (B * T, num_patches, embed_dim)
        B, T, num_patches, embed_dim = x.shape
        
        for block in self.transformer_blocks:
            x = block(x)
        
        # Reshape back to image patches
        x = self.fc(x)
        x = self.reshaper(x)  # (B * T, num_patches, C, patch_size, patch_size)
        x = x.view(B, T, num_patches, -1, self.patch_size, self.patch_size)  # (B, T, num_patches, C, patch_size, patch_size)
        
        return x

### MAE

In [16]:
# MAEViT Model (Masked Autoencoder with Vision Transformer)
class MAEViT(nn.Module):
    def __init__(self, embed_dim=256, num_heads=8, ff_dim=1024, num_layers=6, num_patches=10, patch_size=16, img_size=64, num_channels=10):
        super(MAEViT, self).__init__()
        self.encoder = ViTEncoder(embed_dim, num_heads, ff_dim, num_layers, num_patches, patch_size, img_size, num_channels)
        self.decoder = ViTDecoder(embed_dim, num_heads, ff_dim, num_layers, num_patches, patch_size, img_size, num_channels)

    def forward(self, x, temporal_indices):
        encoded_patches = self.encoder(x, temporal_indices)  # (B * T, num_patches, embed_dim)
        reconstructed = self.decoder(encoded_patches)  # (B, T, num_patches, C, patch_size, patch_size)
        return reconstructed

In [17]:
# Example usage
B = 4  # Batch size
T = 7  # Number of temporal images (7 acquisition dates)
C = 10  # Number of channels
H = 64  # Image height
W = 64  # Image width

images = torch.randn(B, T, C, H, W)  # Example batch of satellite images
temporal_indices = torch.arange(T).unsqueeze(0).expand(B, -1)  # Example temporal indices

# Create model
model = MAEViT(embed_dim=256, num_heads=8, ff_dim=1024, num_layers=6, num_patches=256, patch_size=16, img_size=64, num_channels=10)

# Forward pass
reconstructed_images = model(images, temporal_indices)

print(f"Reconstructed Images shape: {reconstructed_images.shape}")

RuntimeError: size mismatch, got input (256), mat (256x2), vec (128)