In [3]:
import torch
from torch import nn
import numpy as np

def divide_image_into_patches(image, patch_size):
    """Split an image into a row-major grid of non-overlapping patches.

    Args:
        image: Array indexable as ``image[rows, cols]`` whose first two axes
            are height and width, e.g. a 2-D ``(H, W)`` array. Any trailing
            axes (such as channels) are carried along into every patch.
        patch_size: ``(patch_height, patch_width)`` in pixels; both must be
            positive.

    Returns:
        A list of rows, each a list of patches (slices of ``image``; views for
        NumPy arrays). When the image size is not a multiple of the patch
        size, the last row/column holds smaller edge patches instead of being
        dropped.

    Raises:
        ValueError: if ``image`` has fewer than two dimensions or a patch
            dimension is not positive.
    """
    if len(image.shape) < 2:
        raise ValueError(f"image must have at least 2 dimensions, got shape {image.shape}")
    # Only the spatial axes drive tiling; extra axes (channels) ride along.
    img_height, img_width = image.shape[:2]
    patch_height, patch_width = patch_size
    if patch_height <= 0 or patch_width <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    return [
        [
            image[row:row + patch_height, col:col + patch_width]
            for col in range(0, img_width, patch_width)
        ]
        for row in range(0, img_height, patch_height)
    ]

def get_patch_embedding(patch):
    """Return a stand-in embedding for ``patch``.

    Placeholder only: the patch contents are ignored and a fresh
    standard-normal tensor of shape ``(1, 768)`` is drawn from torch's global
    RNG. Replace with a real patch encoder that returns the same shape.
    """
    embedding_shape = (1, 768)
    return torch.randn(embedding_shape)

def create_embedding_structure(image, patch_size, embed_fn=None):
    """Embed every patch of ``image`` and arrange the results as a grid.

    Args:
        image: Image passed to ``divide_image_into_patches``.
        patch_size: ``(patch_height, patch_width)`` passed to
            ``divide_image_into_patches``.
        embed_fn: Callable mapping one patch to an embedding tensor whose
            elements are the embedding (e.g. shape ``(1, D)`` or ``(D,)``).
            Defaults to ``get_patch_embedding``. Edge patches may be smaller
            than ``patch_size`` when the image is not an exact multiple.

    Returns:
        Tensor of shape ``(M, N, D)``: ``M`` patch rows, ``N`` patch columns,
        and ``D`` the embedding width reported by ``embed_fn``.

    Raises:
        ValueError: if the image yields no patches (zero height or width).
    """
    if embed_fn is None:
        embed_fn = get_patch_embedding

    patches = divide_image_into_patches(image, patch_size)
    if not patches or not patches[0]:
        raise ValueError("image produced no patches; check image shape and patch_size")
    M = len(patches)  # Number of rows of patches
    N = len(patches[0])  # Number of columns of patches

    # Embed in row-major order and flatten each result to 1-D, so the width D
    # comes from the embedder instead of a hard-coded constant.
    flat_embeddings = [embed_fn(patch).reshape(-1) for row in patches for patch in row]
    return torch.stack(flat_embeddings).reshape(M, N, -1)

# Example usage: a random 1024x1024 single-channel image cut into 256x256
# patches gives a 4x4 grid of patch embeddings.
rng = np.random.default_rng(0)  # seeded so re-running the cell is reproducible
torch.manual_seed(0)  # the placeholder embedder draws from torch's global RNG
image = rng.random((1024, 1024))  # Example image
patch_size = (256, 256)  # Example patch size

embedding_structure = create_embedding_structure(image, patch_size)
# ceil(H / patch_h) x ceil(W / patch_w) patches, since edge patches are kept.
expected_grid = tuple(-(-dim // step) for dim, step in zip(image.shape, patch_size))
assert embedding_structure.shape[:2] == expected_grid, embedding_structure.shape
print(embedding_structure.shape)  # (patches vertically, patches horizontally, 768)


torch.Size([4, 4, 768])


In [4]:
class EmbedSubtypeClassifier(nn.Module):
    """Small CNN classifier over a grid of patch embeddings.

    Inputs are batches laid out channels-first, ``(batch_size, embed_dim, H, W)``,
    where ``H x W`` is the patch grid. A grid stored as ``(H, W, embed_dim)``
    must be permuted first, e.g. ``grid.permute(2, 0, 1).unsqueeze(0)``.

    Args:
        input_dim: ``(H, W)`` patch-grid size the fully connected head is sized
            for. Both values must be at least 2 because of the 2x2 max-pool.
        output_dim: Number of output units. Each unit passes through its own
            sigmoid, so outputs lie in [0, 1].
        embed_dim: Channels per grid cell (the embedding width). Defaults to 768.

    Raises:
        ValueError: if either grid dimension is smaller than 2.
    """

    def __init__(self, input_dim, output_dim, embed_dim=768):
        super().__init__()
        grid_h, grid_w = input_dim[0], input_dim[1]
        if grid_h < 2 or grid_w < 2:
            # MaxPool2d(2) would shrink such a grid to zero cells, leaving an
            # unusable Linear(0, ...) head.
            raise ValueError(
                f"input_dim must be at least (2, 2) for the 2x2 max-pool, got {tuple(input_dim)}"
            )
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=embed_dim, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # padding=1 keeps H x W through both convs; MaxPool2d(2) floors each side.
        flat_features = 64 * (grid_h // 2) * (grid_w // 2)
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
            nn.Sigmoid(),  # independent per-unit probabilities (binary / multi-label)
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch_size, embed_dim, H, W)`` to ``(batch_size, output_dim)``."""
        x = self.conv_layers(x)
        x = x.flatten(1)  # keep the batch axis, flatten channels and spatial dims
        return self.fc_layers(x)

# Example usage: a batch of 32 random 4x4 grids of 768-dim patch embeddings,
# laid out channels-first as the model expects.
torch.manual_seed(0)  # reproducible weights and inputs on re-run
model = EmbedSubtypeClassifier(input_dim=(4, 4), output_dim=1)  # input_dim = patch-grid size
input_tensor = torch.randn(32, 768, 4, 4)  # (batch, 768 channels, grid_h, grid_w)
output = model(input_tensor)
assert output.shape == (32, 1), output.shape

# The grid built in the previous cell is (H, W, 768): move the embedding axis
# to channels and add a batch axis before feeding it to the model.
grid_batch = embedding_structure.permute(2, 0, 1).unsqueeze(0)
assert model(grid_batch).shape == (1, 1)
print(output.shape)  # (32, 1)


torch.Size([32, 1])
