In [1]:
import openslide as ops
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
# Path of the pickle file that holds the training samples. Despite the "_DIR"
# suffix this is a file path: it is opened directly with open(..., "rb").
DATASET_DIR = "../datasets/merged_embeddings/merged_dataset.pkl"
# Directory containing the whole-slide image files; a sample's "slide_name"
# is appended to this directory to locate the slide to open.
SLIDE_DIR = "../datasets/wsi"

In [3]:
class CustomDataset(Dataset):
    """Pairs image patches read from whole-slide images with their embeddings.

    Each element of ``data`` must be a mapping with the keys ``"slide_name"``,
    ``"x"``, ``"y"``, ``"patch_size"`` and ``"embedding_vector"``. The patch is
    read at pyramid level 0 from the slide ``SLIDE_DIR/<slide_name>`` with its
    top-left corner at ``(x, y)``.

    Slide handles are opened lazily and cached per slide name so consecutive
    samples from the same slide do not reopen the file. Call :meth:`close`
    when the dataset is no longer needed to release those handles.

    Args:
        data: Indexable collection of sample mappings (see above).
        transform: Optional callable applied to the RGB patch before it is
            returned.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        # slide_name -> open slide handle, filled on first access.
        self.slide_cache = {}

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def _get_slide(self, slide_name):
        """Return the cached handle for ``slide_name``, opening it if needed."""
        if slide_name not in self.slide_cache:
            self.slide_cache[slide_name] = ops.OpenSlide(f"{SLIDE_DIR}/{slide_name}")
        return self.slide_cache[slide_name]

    def __getitem__(self, idx):
        """Return ``(image, embedding)`` for sample ``idx``.

        ``image`` is the RGB patch (after ``transform`` if one was given) and
        ``embedding`` is the sample's ``"embedding_vector"`` exactly as stored.
        """
        sample = self.data[idx]
        patch_size = sample["patch_size"]
        # read_region expects a (width, height) pair; a bare integer is
        # interpreted here as a square patch. Pairs pass through untouched.
        if isinstance(patch_size, (int, np.integer)):
            patch_size = (int(patch_size), int(patch_size))

        slide = self._get_slide(sample["slide_name"])
        image = slide.read_region((sample["x"], sample["y"]), 0, patch_size)
        # Convert to a 3-channel RGB image before any transform is applied.
        image = image.convert("RGB")

        embedding = sample["embedding_vector"]
        if self.transform:
            image = self.transform(image)
        return image, embedding

    def close(self):
        """Close every cached slide handle and empty the cache.

        The dataset stays usable afterwards: slides are simply reopened on the
        next access.
        """
        for slide in self.slide_cache.values():
            slide.close()
        self.slide_cache.clear()

In [None]:
# Load the pickled samples from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserialising,
# so this file must come from a trusted source.
with open(DATASET_DIR, "rb") as dataset_file:
    train_dataset = pickle.load(dataset_file)

# Inspect one sample to sanity-check the structure of the loaded data.
idx = 100
print("Length of train dataset: ", len(train_dataset))
sample = train_dataset[idx]
print("Train dataset keys: ", sample.keys())
print("Slide name: ", sample["slide_name"])
print("Embedding vector shape: ", sample["embedding_vector"].shape)

# Preprocessing: convert the patch to a tensor, then resize it to 224x224.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
]
transform = transforms.Compose(preprocessing_steps)

train_data = CustomDataset(train_dataset, transform=transform)
print("Length of train data: ", len(train_data))

In [5]:
class EmbeddingToImageDecoder(nn.Module):
    """Decode an embedding vector into a 3-channel 224x224 image.

    A stack of fully connected layers expands the embedding to a 256x7x7
    feature map, which five stride-2 transposed convolutions then upsample
    (7 -> 14 -> 28 -> 56 -> 112 -> 224). The final activation is ``tanh``, so
    output values lie in [-1, 1].

    Args:
        embedding_size: Length of the input embedding vector.
    """

    def __init__(self, embedding_size=384):
        super().__init__()

        # Fully connected part: each Linear layer is followed by a ReLU.
        widths = [embedding_size, 1024, 2048, 4096, 8192, 7 * 7 * 256]
        dense_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            dense_layers.append(nn.Linear(fan_in, fan_out))
            dense_layers.append(nn.ReLU())
        self.mlp = nn.Sequential(*dense_layers)

        # Upsampling part: every stage doubles the spatial size.
        channels = [256, 256, 128, 64, 32]
        deconv_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            deconv_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        # Output stage: map to 3 channels and squash with tanh.
        deconv_layers.extend([
            nn.ConvTranspose2d(channels[-1], 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ])
        self.upconv = nn.Sequential(*deconv_layers)

    def forward(self, x):
        """Return the decoded image batch of shape ``(x.size(0), 3, 224, 224)``."""
        features = self.mlp(x)
        # Reinterpret the flat feature vector as a 256-channel 7x7 map.
        feature_map = features.view(features.size(0), 256, 7, 7)
        return self.upconv(feature_map)


In [6]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by a skip connection.

    Spatial size is preserved (stride 1, padding 1). When the channel count
    changes, the skip path uses a 1x1 convolution followed by batch norm so
    it can be added to the main path; otherwise the input is passed through
    unchanged.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if in_channels == out_channels:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so its channel count matches the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += skip
        return self.relu(main)

class ViTDecoderWithResiduals(nn.Module):
    """Decode an embedding vector into a 3-channel 224x224 image.

    A single linear layer maps the embedding to a 256x7x7 feature map. Four
    upsampling stages follow, each a stride-2 transposed convolution with
    batch norm and LeakyReLU, refined by a residual block. A last transposed
    convolution produces 3 channels and a sigmoid keeps the output in [0, 1].

    Args:
        embedding_dim: Length of the input embedding vector.
    """

    def __init__(self, embedding_dim=384):
        super().__init__()

        def upsample_stage(in_ch, out_ch):
            """Build one stage that doubles the spatial size."""
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            )

        # MLP: go from the embedding to a flat feature map.
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 7 * 7 * 256),
            nn.ReLU(),
        )

        # Upsampling stages, each followed by a residual block.
        self.upconv1 = upsample_stage(256, 128)  # 7x7 -> 14x14
        self.residual1 = ResidualBlock(128, 128)

        self.upconv2 = upsample_stage(128, 64)  # 14x14 -> 28x28
        self.residual2 = ResidualBlock(64, 64)

        self.upconv3 = upsample_stage(64, 32)  # 28x28 -> 56x56
        self.residual3 = ResidualBlock(32, 32)

        self.upconv4 = upsample_stage(32, 16)  # 56x56 -> 112x112
        self.residual4 = ResidualBlock(16, 16)

        # Final layer: 112x112 -> 224x224 with 3 output channels.
        self.upconv5 = nn.Sequential(
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, z):
        """Return the decoded image batch of shape ``(z.size(0), 3, 224, 224)``."""
        flat = self.mlp(z)
        # Convert the flat features to image layout (channels, height, width).
        feats = flat.view(flat.size(0), 256, 7, 7)

        stages = (
            (self.upconv1, self.residual1),
            (self.upconv2, self.residual2),
            (self.upconv3, self.residual3),
            (self.upconv4, self.residual4),
        )
        for upsample, refine in stages:
            feats = refine(upsample(feats))

        return torch.sigmoid(self.upconv5(feats))


In [7]:
def train_model(train_loader, model, optimizer, criterion, device, num_epochs=10):
    """Train ``model`` to reconstruct images from their embeddings.

    Args:
        train_loader: Iterable yielding ``(image, embedding)`` batches.
        model: Module called as ``model(embedding)``; its output is compared
            to ``image`` by ``criterion``.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(output, image)``.
        device: Device the batches are moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one entry per epoch holding the sample-weighted mean loss
        of that epoch. An epoch that saw no samples (empty loader) is recorded
        as ``nan`` instead of raising.
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        loss_sum = 0.0
        samples_seen = 0
        for image, embedding in train_loader:
            # A singleton axis is inserted at dim 1, as in the original loop.
            # Cast to float32 *before* the device transfer: the model weights
            # are float32 and some backends (e.g. MPS) reject float64 tensors.
            embedding = embedding.unsqueeze(1).float().to(device)
            image = image.to(device)

            optimizer.zero_grad()
            output = model(embedding)
            loss = criterion(output, image)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller final batch does not skew
            # the epoch mean.
            batch_count = image.size(0)
            loss_sum += loss.item() * batch_count
            samples_seen += batch_count

        mean_loss = loss_sum / samples_seen if samples_seen else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}, Train Loss: {mean_loss}")
    return epoch_losses

In [None]:
# Build the decoder and move it to the Apple-silicon GPU (MPS) when one is
# available, otherwise stay on the CPU.
model = ViTDecoderWithResiduals()
model.train()
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available() is not present on every PyTorch release.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

In [None]:
# Training configuration: mean-squared-error reconstruction loss optimised
# with AdamW over shuffled mini-batches.
batch_size = 512
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-2,
    weight_decay=1e-5,
)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

# Run the training loop for 10 epochs.
train_model(
    train_loader=train_loader,
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=10,
)

In [10]:
# Persist only the model's parameters and buffers (its state dict) to a file
# in the current working directory.
torch.save(model.state_dict(), "decoder_with_residuals.pth")

In [None]:
# Load the saved state dict back into the model.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from one device (e.g. MPS) still loads on a machine
# where that device does not exist.
model.load_state_dict(torch.load("decoder_with_residuals.pth", map_location=device))
model.to(device)

In [None]:
# Qualitative check: show original patches next to the decoder's
# reconstruction from the matching embedding.
model.eval()

# The previous slice, train_dataset[11110:11110], was empty, so nothing was
# ever displayed. Preview `preview_count` samples starting at `preview_start`,
# clamped to the dataset length.
preview_start = 11110
preview_count = 1
preview_stop = min(preview_start + preview_count, len(train_data))

for sample_idx in range(preview_start, preview_stop):
    # Go through train_data so the displayed patch is produced by the same
    # read + transform pipeline the model was trained against.
    image, embedding = train_data[sample_idx]

    embedding = torch.as_tensor(embedding).float().unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(embedding)
    # Drop the batch axis and move channels last for imshow.
    output = np.moveaxis(output.squeeze(0).cpu().numpy(), 0, -1)
    target = np.moveaxis(image.numpy(), 0, -1)

    # One figure per sample: a figure closed after plt.show() cannot be
    # reused by a later iteration.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(target)
    ax[0].set_title("Original patch")
    ax[1].imshow(output)
    ax[1].set_title("Reconstruction")
    plt.show()
    plt.close(fig)
