In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import torch.nn.functional as F

# Dataset class to load the fine and coarse scale dynamics (3-index tensors)
class DynamicsDataset(Dataset):
    def __init__(self, fine_scale_images, coarse_scale_images, transform=None):
        self.fine_scale_images = fine_scale_images
        self.coarse_scale_images = coarse_scale_images
        self.transform = transform

    def __len__(self):
        return self.fine_scale_images.shape[0]  # T: number of time steps

    def __getitem__(self, idx):
        fine_img = self.fine_scale_images[idx]  # shape: (192, 192)
        coarse_img = self.coarse_scale_images[idx]  # shape: (384, 384)

        # Resize the images manually if needed (replace `transform` logic)
        if self.transform:
            fine_img = self.transform(fine_img)
            coarse_img = self.transform(coarse_img)

        return fine_img, coarse_img


# Manual resizing function (equivalent to torchvision.transforms.Resize)
def resize_image(image, size):
    return F.interpolate(image.unsqueeze(0), size=size, mode='bilinear', align_corners=False).squeeze(0)

class VisionTransformer(nn.Module):
    def __init__(self, image_size=192, patch_size=16, dim=512, depth=12, heads=8, mlp_dim=1024):
        super(VisionTransformer, self).__init__()

        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        # Patch Embedding: Linear projection of flattened patches
        self.patch_embed = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        
        # Positional Encoding (learnable)
        self.positional_encoding = nn.Parameter(torch.zeros(1, self.num_patches, dim))
        
        # Transformer Encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim)
            for _ in range(depth)
        ])
        
        # Upsample Layer
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Final Convolution Layer
        self.fc = nn.Conv2d(dim, 1, kernel_size=1)  # Output one channel for the coarse image

    def forward(self, x):
        # x: (batch_size, 1, 192, 192) for fine scale input

        # Patch embedding: (batch_size, 1, 192, 192) -> (batch_size, dim, num_patches, 1) -> (batch_size, num_patches, dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        
        # Add positional encoding to patches
        x = x + self.positional_encoding
        
        # Pass through the transformer encoder layers
        for layer in self.encoder_layers:
            x = layer(x)
        
        # Reshape and apply the final convolution
        x = x.transpose(1, 2).reshape(x.shape[0], -1, int(math.sqrt(self.num_patches)), int(math.sqrt(self.num_patches)))

        # First upsample (scale factor 2) to get (batch_size, 1, 48, 48)
        x = self.upsample(x)

        # Second upsample (scale factor 2) to get (batch_size, 1, 96, 96)
        x = self.upsample(x)

        # Third upsample (scale factor 2) to get (batch_size, 1, 192, 192)
        x = self.upsample(x)

        # Apply the final convolution (single channel output)
        x = self.fc(x)  # Final output layer, producing coarse scale prediction

        return x




# Training function
def train_vit(model, train_loader, num_epochs=10, lr=1e-4):
    # Loss function and optimizer
    criterion = nn.MSELoss()  # For regression
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            # Move data to device (GPU/CPU)
            inputs, targets = inputs.to(device), targets.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)

            # Compute loss
            loss = criterion(outputs, targets)
            running_loss += loss.item()

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")


# Hyperparameters
image_size = 192  # Input image size
patch_size = 16  # Patch size for ViT
dim = 512  # Dimensionality of transformer
depth = 12  # Number of transformer encoder layers
heads = 8  # Number of attention heads
mlp_dim = 1024  # Feedforward network dimension
batch_size = 8
num_epochs = 10

# Assuming your fine_scale_images and coarse_scale_images are 3D tensors:
# fine_scale_images.shape = (T, 192, 192), coarse_scale_images.shape = (T, 384, 384)
# Where T is the number of time steps

# Load your data (replace with actual data loading)
fine_scale_images = torch.randn(78, 192, 192)  # Placeholder
coarse_scale_images = torch.randn(78, 384, 384)  # Placeholder

# Reshape to (T, H, W, C) to match dataset input format (batch_size, channels, height, width)
fine_scale_images = fine_scale_images.unsqueeze(1)  # shape: (T, 1, 192, 192)
coarse_scale_images = coarse_scale_images.unsqueeze(1)  # shape: (T, 1, 384, 384)

# Create dataset and dataloaders
dataset = DynamicsDataset(fine_scale_images, coarse_scale_images, transform=lambda x: resize_image(x, (192, 192)))
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer(image_size=image_size, patch_size=patch_size, dim=dim, depth=depth, heads=heads, mlp_dim=mlp_dim).to(device)

# Train the model
train_vit(model, train_loader, num_epochs=num_epochs, lr=1e-4)

# Save the trained model
torch.save(model.state_dict(), 'vit_super_resolution_from_scratch.pth')


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (96) must match the size of tensor b (192) at non-singleton dimension 3