<a href="https://colab.research.google.com/github/beamaia/cvae-clustering/blob/main/histvae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install torch \
torchvision \
matplotlib \
numpy \
scikit-learn \
opencv-python \
openslide-python \
pandas \
Pillow \
scipy \
wandb

**Based on [HistoVAE](https://github.com/beamaia/HistoVAE/blob/main/src/models.py)**

In [2]:
#torch==2.2.1
#torchvision==0.17.1
#matplotlib==3.8.3
#numpy==1.26.4
#opencv-python==4.9.0.80
#pandas==2.2.0
#Pillow==10.2.0
#scipy==1.12.0
#wandb==0.16.3

In [3]:
import torch
from torch import nn
from math import log
from torch.nn import functional as F
import torchvision.transforms as T
import numpy as np

In [4]:
class View(nn.Module):
    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)

In [306]:
class VCAE(nn.Module):
    def __init__(self, inner_dim:int = 2048, dropout_rate:float = 0., device='cpu'):
        super(VCAE, self).__init__()
        self.inner_dim = inner_dim
        self.dropout_rate = dropout_rate
        self.epsilon_std = 1.0

        self.device = device

        self.encoder = self._define_encoder().to(device)
        self.decoder = self._define_decoder().to(device)

        self.z_mean = nn.Linear(128* 12* 12, self.inner_dim).to(device)
        self.z_log_var = nn.Linear(128* 12* 12, self.inner_dim).to(device)


    def _define_encoder(self, x=None):
         return nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(16, 32, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(32, 64, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Conv2d(64, 128, kernel_size=3),
                            nn.LeakyReLU(),
                            nn.AvgPool2d(kernel_size=2, stride=2),

                            nn.Flatten(),
                            nn.Dropout(self.dropout_rate)
                            )


    def _define_decoder(self):
        return nn.Sequential(nn.Dropout(self.dropout_rate),
                            nn.Linear(self.inner_dim, 128*12*12),
                            View((-1, 128, 12, 12)),

                            nn.Upsample(size=(24, 24)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(128, 64, kernel_size=3),

                            nn.Upsample(size=(52, 52)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(64, 32, kernel_size=3),


                            nn.Upsample(size=(109, 109)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(32, 16, kernel_size=3),

                            nn.Upsample(size=(223, 223)),
                            nn.LeakyReLU(),
                            nn.ConvTranspose2d(16, 3, kernel_size=3),
                            )

    def _resize_input(self, image):
        if image.shape[2] == 225 and image.shape[3] == 225:
            return image
        return T.Resize(size=225)(image).to(self.device)


    def encode(self, image):
        image = self._resize_input(image)
        x = self.encoder(image)
        z_mean = self.z_mean(x)
        z_log_var = self.z_log_var(x)
        return z_mean, z_log_var

    def decode(self, features):
        decoded = self.decoder(features)
        return decoded


    def sampling(self, z_mean, z_log_var):
        epsilon = torch.normal(size=(z_mean.shape[0], self.inner_dim), mean=0, std=self.epsilon_std).to(self.device)
        return z_mean + torch.exp(z_log_var) * epsilon


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z_mean, z_log_var = self.encode(x)
        z = self.sampling(z_mean, z_log_var)
        decoded = self.decoder(z)

        return z_mean, z_log_var, z, decoded



In [307]:
image = torch.randint(0, 255, (3, 3, 225, 225), dtype=torch.float32)

In [308]:
vae = VCAE(dropout_rate=0.5)

In [309]:
features = vae.encode(image)

---

In [310]:
import torchvision
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset

In [311]:
import pathlib as pl
from glob import glob

In [312]:
from PIL import Image

In [313]:
class PNDBDataset(Dataset):
    """P-NDB-Dataset"""

    def __init__(self, folder_path, transform=None):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.path = folder_path
        self.classes_to_idx = {
            "carcinoma": 0,
            "no_dysplasia": 1,
            "with_dysplasia": 2
        }

        self.img_transform = T.Compose([T.PILToTensor()])
        self.transform = transform
        self._organize_image()


    def _load_image(self, path):
        im = Image.open(path)
        im = self.img_transform(im)
        im = T.Resize(size=225)(im)
        return im

    def _organize_image(self):
        carcinoma_images_path = pl.Path(f"{self.path}/carcinoma")
        no_dysplasia_images_path = pl.Path(f"{self.path}/no_dysplasia")
        dysplasia_images_path = pl.Path(f"{self.path}/dysplasia")

        carcinoma_images = glob(f'{str(carcinoma_images_path)}/*.png')
        no_dysplasia_images = glob(f'{str(no_dysplasia_images_path)}/*.png')
        dysplasia_images = glob(f'{str(dysplasia_images_path)}/*.png')

        carcinoma_image_list = [self._load_image(path) for path in carcinoma_images]
        no_dysplasia_image_list = [self._load_image(path) for path in no_dysplasia_images]
        dysplasia_image_list = [self._load_image(path) for path in dysplasia_images]

        carcinoma_targets = torch.zeros(len(carcinoma_images))
        no_dysplasia_targets = torch.ones(len(no_dysplasia_images))
        dysplasia_targets = torch.ones(len(dysplasia_images)) * 2

        images_list = []
        images_list.extend(carcinoma_image_list)
        images_list.extend(no_dysplasia_image_list)
        images_list.extend(dysplasia_image_list)

        targets_list = []
        targets_list.extend(carcinoma_targets)
        targets_list.extend(no_dysplasia_targets)
        targets_list.extend(dysplasia_targets)

        self.data = torch.tensor(np.array(images_list)) / 1.
        self.target = torch.tensor(targets_list)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self.data[idx]
        target = self.target[idx]

        if self.transform:
            image = self.transform(image)

        return image, target

In [314]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [315]:
!pip install patool
import patoolib



In [316]:
#patoolib.extract_archive('/content/drive/MyDrive/dataset/data_train_test.zip')

In [317]:
learning_rate = 1e-3
batch_size = 32
epochs = 200

In [318]:
train_dataset = PNDBDataset("/content/data/train")
test_dataset = PNDBDataset("/content/data/test")



In [319]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [320]:
optimizer = torch.optim.SGD(vae.parameters(), lr=learning_rate)

In [321]:
loss_fn = F.mse_loss

In [322]:
!pip install wandb
import wandb



In [323]:
wandb.login()

True

In [324]:
alpha = 1
beta = 1

In [325]:
wandb.init(
        project="tcc",
        config={
            "epochs": epochs,
            "batch_size": batch_size,
            "lr": learning_rate,
            "optimizer": "SGD",
            "alpha": alpha,
            "beta": beta
            })

In [326]:
device = torch.device('cuda')

In [327]:
vae = VCAE(dropout_rate=0.1, device=device)

In [None]:
vae.train()
total_loss = []
n_steps_per_epoch = len(train_dataloader)
for epoch in range(epochs):
    epochs_loss = []
    print("Epoch:", epoch)
    for i, data in enumerate(train_dataloader):
        # Compute prediction and loss
        X, y = data
        X, y = X.to(device), y.to(device)
        mean, logvar, sampling, decoded = vae(X)

        kl = - 0.5 * torch.mean(1 + logvar - mean**2 - torch.exp(logvar), axis=1)
        # kl_mean = kl.mean()
        batchsize = X.shape[0]

        x_flat = X.flatten()
        decoded_flat = decoded.flatten()

        pixelwise = loss_fn(x_flat, decoded_flat, reduction='none')
        # pixelwise_mean = pixelwise.mean()


        loss = alpha * pixelwise.mean() + beta * kl.mean()
        loss = loss.mean()

        if not i % 10:
          print("pixelwise", pixelwise)
          print("kl", kl)
          print("loss", loss)
          print()

        optimizer.zero_grad()

        # Backpropagation
        loss.backward()
        optimizer.step()

        epochs_loss.append(loss.cpu().detach().numpy())
        metrics = {"train/train_loss": loss,
                    "train/epoch": (i + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                    "train/kl_div": kl,
                    "train/mse_loss": pixelwise,
                    "train/epoch": epoch + 1}

        wandb.log(metrics)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        6.7822, 6.6103, 4.2818, 4.9455, 5.0112], device='cuda:0',
       grad_fn=<MulBackward0>)
loss tensor(25597.3672, device='cuda:0', grad_fn=<MeanBackward0>)

pixelwise tensor([49210.1797, 46153.6914, 46146.2266,  ..., 48861.2070, 50353.9453,
        50270.8984], device='cuda:0', grad_fn=<MseLossBackward0>)
kl tensor([ 4.4190, 16.1393,  4.0533,  6.3415,  3.8422,  2.7662,  4.8963,  4.6539,
         5.2206,  5.4804,  8.0233,  5.6711,  4.5247,  7.8038,  8.9341,  4.2130,
         5.8892,  9.1157,  8.3484,  5.7546,  5.9582,  3.1322,  5.2301,  3.7363,
         5.8131,  9.7197,  6.7851,  6.7337,  2.5555,  3.8762,  5.8650,  8.2834],
       device='cuda:0', grad_fn=<MulBackward0>)
loss tensor(26503.1621, device='cuda:0', grad_fn=<MeanBackward0>)

pixelwise tensor([32694.7988, 33062.9922, 32332.5508,  ..., 19897.2852, 21063.3965,
        21053.4707], device='cuda:0', grad_fn=<MseLossBackward0>)
kl tensor([5.1198, 8.3524, 4.260

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(test_dataset[50][0].T)
plt.show()

In [None]:
train_dataset.classes_to_idx

In [None]:
x_output = vae(test_dataset[50][0].unsqueeze(0).to(device))

In [None]:
decoded_image = x_output[3].cpu().detach()[0]

In [None]:
decoded_image = decoded_image.type(torch.uint8)

In [None]:
decoded_image

In [None]:
imgplot = plt.imshow(decoded_image.T)
plt.show()

In [None]:
wandb.finish()