In [1]:
import torch

# Constants
model_size = 1.3 * (1024 ** 2)  # Model size in bytes
input_size_per_image = 22 * (1024 ** 2)  # Size of one image in bytes
feature_size = 50 * (1024 ** 2)  # Size of features in bytes
available_memory = 24 * (1024 ** 3)  # Available GPU memory in bytes

# Estimate the overhead (2-3 times the model and input size)
overhead_factor = 3

# Finding maximum batch size
max_batch_size = 1
while True:
    input_size = max_batch_size * (input_size_per_image * 2)  # For two images
    features_size = max_batch_size * feature_size  # If applicable
    total_memory = (model_size + input_size + features_size) * overhead_factor
    
    if total_memory > available_memory:
        break
    max_batch_size += 1

print(f'Maximum batch size: {max_batch_size - 1}')


Maximum batch size: 87


In [31]:
import torch

# Model constants
model_size_mb = 1.3 * 1024  # ViT-B MAE model size in MB
feature_dim = 768  # ViT-B feature dimension 
window_size = 224
overlap = 56
patch_size = 16

# Image parameters (typical satellite image)
img_height = 2024 
img_width = 2024
img_channels = 3

# Calculate windows per image
stride = window_size - overlap
windows_h = max(1, (img_height - window_size) // stride + 1)
windows_w = max(1, (img_width - window_size) // stride + 1)
total_windows = windows_h * windows_w

print(f"Image dimensions: {img_height}x{img_width}")
print(f"Windows per dimension: {windows_h}x{windows_w} = {total_windows} total windows")

# Image memory requirements
image_size_mb = 2 * img_height * img_width * img_channels * 4 / (1024*1024)  # Two images (before/after), 4 bytes per float

# Window processing (peak during feature extraction)
inference_window_batch_size = 32  # Adjust based on implementation
peak_window_batch_mb = window_size * window_size * img_channels * 4 * inference_window_batch_size / (1024*1024)

# Feature extraction memory
feature_tokens_per_window = (window_size // patch_size) ** 2 + 1  # +1 for CLS token
feature_size_per_window_mb = feature_tokens_per_window * feature_dim * 4 / (1024*1024)  # 4 bytes per float
feature_map_mb = feature_size_per_window_mb * total_windows  # Dictionary of all window features

# Final merged feature representation (only active tokens in circle)
feature_grid_h = img_height // patch_size
feature_grid_w = img_width // patch_size
total_grid_size = feature_grid_h * feature_grid_w
circle_ratio = 3.14159 / 4  # π/4 (ratio of circle to square)
active_tokens = int(total_grid_size * circle_ratio) + 1  # +1 for CLS
final_feature_mb = active_tokens * feature_dim * 4 / (1024*1024)

# Total memory per sample (accounting for both before/after images)
# During processing we need: raw images + peak of either window batch or feature map + final representation
processing_memory_mb = image_size_mb + max(peak_window_batch_mb, feature_map_mb) + 2 * final_feature_mb

# Training specific memory (gradients, optimizer states)
training_overhead_factor = 3  # Higher for training due to gradients
training_memory_mb = processing_memory_mb * training_overhead_factor

# Calculate max batch size based on available GPU memory
available_memory_mb = 24 * 1024  # 24 GB GPU
max_batch_size = max(1, int((available_memory_mb - model_size_mb) / training_memory_mb))

# Consider gradient accumulation
gradient_accumulation_steps = 4
effective_batch_size = max_batch_size * gradient_accumulation_steps

print(f"\nMemory requirements:")
print(f"  Raw images: {image_size_mb:.2f} MB")
print(f"  Feature map dictionary: {feature_map_mb:.2f} MB")
print(f"  Peak window batch: {peak_window_batch_mb:.2f} MB")
print(f"  Final feature representation: {2 * final_feature_mb:.2f} MB (both images)")
print(f"  Total processing memory: {processing_memory_mb:.2f} MB")
print(f"  Total training memory: {training_memory_mb:.2f} MB (with gradients)")
print(f"\nBatch size calculations:")
print(f"  Model size: {model_size_mb:.2f} MB")
print(f"  Available GPU memory: {available_memory_mb:.2f} MB")
print(f"  Maximum batch size: {max_batch_size}")
print(f"  Effective batch size with gradient accumulation: {effective_batch_size}")

Image dimensions: 2024x2024
Windows per dimension: 11x11 = 121 total windows

Memory requirements:
  Raw images: 93.76 MB
  Feature map dictionary: 69.83 MB
  Peak window batch: 18.38 MB
  Final feature representation: 73.06 MB (both images)
  Total processing memory: 236.66 MB
  Total training memory: 709.98 MB (with gradients)

Batch size calculations:
  Model size: 1331.20 MB
  Available GPU memory: 24576.00 MB
  Maximum batch size: 32
  Effective batch size with gradient accumulation: 128


In [17]:
memory_per_sample

838.986328125

In [27]:
H=2017
W=2028
window_size=16
window_n1 = 0
h_center, w_center = H/2, W/2

for h in range(0, H - window_size + 1, window_size):
        for w in range(0, W - window_size + 1, window_size):
            window_n1 += 1
window_n1

15876

In [28]:
# Calculate circle parameters
H=2017
W=2028
window_size=16
window_n2 = 0
h_center, w_center = H/2, W/2
radius = min(H, W)/2 - 1  # Slightly smaller than half the image dimension

# Create batches of windows, but only include windows that intersect with the circle
for h in range(0, H - window_size + 1, window_size):
    for w in range(0, W - window_size + 1, window_size):
        # Calculate window center
        window_center_h = h + window_size/2
        window_center_w = w + window_size/2
        
        # Calculate distance from window center to image center
        dist_to_center = ((window_center_h - h_center)**2 + (window_center_w - w_center)**2)**0.5
        
        # Only process windows that are at least partially within the circle
        window_radius = window_size/2 * 1.414  # Diagonal radius of window (√2 × half_size)
        if dist_to_center <= radius + window_radius:
            window_n2 += 1
window_n2

12713

In [29]:
window_n2/window_n1

0.8007684555303602