In [1]:
from functools import partial

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

from models.data import JSRTDataModule, CheXpertDataModule
from models.mae.mae import ViTAE
from models.dinosaur import DINOSAUR

from models.encoders import CNNEncoder, ResNet34_8x8, get_resnet34_encoder

In [2]:
# Device string used by the manual inference cells below (the Trainer picks its GPU via `devices=`).
cuda_device = str(torch.device('cuda', 1))

In [3]:
# Both encoders are restored from the same MAE pre-training checkpoint; in this
# notebook one is used as the frozen feature target and the other is fine-tuned
# inside DINOSAUR (see the `oss = DINOSAUR(saved_model, saved_model2, ...)` cell).
MAE_CHECKPOINT = '/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/bash_scripts/lightning_logs/chestxray_mae/chestxray_mae/gn9nzdz8/checkpoints/epoch=503-step=95256.ckpt'

# ViT-B/16-sized encoder (768-dim, 12 layers, 12 heads) on single-channel 224x224 images.
MAE_MODEL_KWARGS = {
    'img_size': 224,
    'embed_dim': 768,
    'in_chans': 1,
    'num_heads': 12,
    'depth': 12,
    'decoder_embed_dim': 512,
    'decoder_depth': 8,
    'decoder_num_heads': 16,
    'norm_layer': partial(nn.LayerNorm, eps=1e-6),
    'mlp_ratio': 4.0,
    'patch_size': 16,
    'norm_pix_loss': False,
}


def load_mae(checkpoint_path=MAE_CHECKPOINT):
    """Load an independent ViTAE (MAE) instance from ``checkpoint_path`` onto the CPU."""
    return ViTAE.load_from_checkpoint(
        checkpoint_path,
        model_kwargs=dict(MAE_MODEL_KWARGS),  # fresh dict per call so the two models never share it
        learning_rate=1e-4,
        map_location=torch.device('cpu'),
    )


saved_model = load_mae()
saved_model2 = load_mae()

# NOTE(review): resnet34_8x8 is not used anywhere else in this notebook.
resnet34_8x8 = ResNet34_8x8(get_resnet34_encoder())         # output is shape (batch_size, 256, 8, 8)

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torchmetrics.functional import dice
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
from pytorch_lightning import LightningModule
import wandb

from utils.utils import SoftPositionEmbed, spatial_broadcast, unstack_and_split
from models.slot_attention import ProbabalisticSlotAttention, FixedSlotAttention, SlotAttention
from models.decoders import SlotSpecificDecoder, Decoder, MlpDecoder

class DINOSAUR(LightningModule):
    """Slot-attention model trained to reconstruct frozen MAE patch features (DINOSAUR-style).

    The frozen encoder provides per-patch target features; the trainable encoder's
    patch features are normalised, projected to ``slot_dim`` and grouped by slot
    attention; an MLP decoder reconstructs the frozen features from the slots.

    NOTE(review): this notebook-local class shadows ``models.dinosaur.DINOSAUR``
    imported in the first cell.

    Args:
        frozen_encoder: object with a ``.model`` exposing ``forward_encoder(x, 0.0)``
            that returns a 3-tuple whose first item has a leading CLS token; its
            parameters are frozen.
        trainable_encoder: same interface; its ``.model`` is trained.
        num_slots: number of slots.
        num_iterations: slot-attention iterations.
        num_classes: stored; not used by the reconstruction loss.
        slot_dim: slot / projected feature width.
        task: stored task tag (only reconstruction is implemented).
        include_seg_loss: stored flag; not used by the current loss.
        probabilistic_slots: use ``ProbabalisticSlotAttention`` instead of ``SlotAttention``.
        learning_rate: peak Adam learning rate.
        hidden_decoder_dim: hidden width of the ``MlpDecoder``.
        temperature: slot-attention temperature.
        lr_warmup: enable the warmup + exponential-decay LR schedule.
        log_images: log image grids / attention maps to the W&B logger; hyperparameters
            are only saved when this is True at construction time.
    """

    # Encoder embedding width (matches embed_dim=768 of the MAE encoders).
    ENCODER_DIM = 768
    # Patch grid for 224x224 inputs with 16x16 patches.
    GRID_SIZE = (14, 14)

    def __init__(self, frozen_encoder, trainable_encoder, num_slots, num_iterations, num_classes, slot_dim=128, task='recon', include_seg_loss=False, probabilistic_slots=True,
                 learning_rate=1e-3, hidden_decoder_dim=2048, temperature=1, lr_warmup=True, log_images=True):
        super().__init__()

        self.frozen_encoder = frozen_encoder.model
        self.trainable_encoder = trainable_encoder.model
        # NOTE(review): encoder_pos_embeddings, decoder_pos_embeddings and embedding_norm_decoder
        # are not used in forward(); they are kept so existing checkpoints' state_dict keys still
        # match. They would trip DDP's unused-parameter check if multi-GPU training is enabled.
        self.encoder_pos_embeddings = SoftPositionEmbed(slot_dim, self.GRID_SIZE)
        if probabilistic_slots:
            self.slot_attention = ProbabalisticSlotAttention(num_slots=num_slots, dim=slot_dim, num_iterations=num_iterations, temperature=temperature)
        else:
            self.slot_attention = SlotAttention(num_slots, slot_dim, num_iterations, temperature=temperature)

        # decoder: slots -> reconstructed encoder features and per-slot masks
        num_patches = self.GRID_SIZE[0] * self.GRID_SIZE[1]
        self.mlp_decoder = MlpDecoder(slot_dim, self.ENCODER_DIM, num_patches, hidden_features=hidden_decoder_dim)

        # projection of encoder features to the slot dimension before slot attention
        self.mlp = nn.Sequential(
            nn.Linear(self.ENCODER_DIM, slot_dim),
            nn.ReLU(),
            nn.Linear(slot_dim, slot_dim),
        )

        self.decoder_pos_embeddings = SoftPositionEmbed(slot_dim, self.GRID_SIZE)
        self.num_classes = num_classes
        self.task = task
        self.learning_rate = learning_rate
        self.lr_warmup = lr_warmup
        self.embedding_norm = nn.LayerNorm(self.ENCODER_DIM)
        self.embedding_norm_decoder = nn.LayerNorm(slot_dim)

        self.include_seg_loss = include_seg_loss

        # detached copies of the first sample of the latest batch, used for logging/plotting
        self.attn = None
        self.masks = None
        self.recons = None
        self.preds = None

        for param in self.frozen_encoder.parameters():
            param.requires_grad = False

        self.mlp_decoder.apply(self.init_weights)
        self.mlp.apply(self.init_weights)
        self.slot_attention.apply(self.init_weights)

        self.test_predmaps = []
        # NOTE(review): no probability maps are produced by the reconstruction objective;
        # kept for interface compatibility.
        self.test_probmaps = []

        # for logging
        self.log_images = log_images
        if log_images:
            # the encoders are nn.Modules already stored in the checkpoint state_dict;
            # keep them out of the hyperparameters
            self.save_hyperparameters(ignore=['frozen_encoder', 'trainable_encoder'])

        self.train_imgs = None
        self.train_preds = None
        self.val_imgs = None
        self.val_preds = None
        self.logged_train_images_this_epoch = False
        self.logged_val_images_this_epoch = False

    def forward(self, x):
        """Encode ``x``, group the trainable patch features into slots and decode them.

        Returns:
            ``(targets, recons, masks, preds, slot_attn)``: frozen patch features with the
            CLS token removed, the decoder's reconstructions and masks,
            ``argmax(masks, dim=1)``, and the slot-attention map.
        """
        with torch.no_grad():
            # second argument 0.0 — presumably the MAE mask ratio, i.e. no masking; TODO confirm
            patch_embeddings_frozen, _, _ = self.frozen_encoder.forward_encoder(x, 0.0)
            patch_embeddings_frozen = patch_embeddings_frozen[:, 1:, :]     # exclude cls token

        # generate patch features from trainable encoder
        patch_embeddings_train, _, _ = self.trainable_encoder.forward_encoder(x, 0.0)
        patch_embeddings_train = patch_embeddings_train[:, 1:, :]           # exclude cls token

        # normalise and project encoder features before slot attention
        patch_embeddings_train = self.mlp(self.embedding_norm(patch_embeddings_train))

        slots, slot_attn = self.slot_attention(patch_embeddings_train)

        recons, masks = self.mlp_decoder(slots)

        # keep the first sample's masks for logging; detached so the autograd graph
        # is not kept alive between steps
        self.masks = masks[0, ...].detach()

        preds = torch.argmax(masks, dim=1)

        return patch_embeddings_frozen, recons, masks, preds, slot_attn

    def init_weights(self, m):
        """Xavier-uniform weights and zero biases for Conv2d/Linear layers (used with ``Module.apply``)."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def configure_optimizers(self):
        """Adam (weight decay 1e-6), optionally with a per-step warmup/decay LambdaLR schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=1e-6)
        if not self.lr_warmup:
            return optimizer

        scheduler = self.warmup_lr_scheduler(optimizer, 10000, 100000, 1e-6, self.learning_rate)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "frequency": 1,
                "interval": "step",
            },
        }

    def warmup_lr_scheduler(self, optimizer, warmup_steps, decay_steps, start_lr, target_lr):
        """LambdaLR: linear ramp from ``start_lr`` to ``target_lr`` over ``warmup_steps``,
        then the LR halves every ``decay_steps`` steps (factors are relative to ``target_lr``)."""
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps)) * (target_lr - start_lr) / target_lr + start_lr / target_lr
            else:
                return 0.5 ** ((current_step - warmup_steps) / decay_steps)
        return LambdaLR(optimizer, lr_lambda)

    def process_batch(self, batch, batch_idx=0):
        """Run the model on ``batch['image']`` and compute the reconstruction loss.

        ``batch_idx`` is unused (kept for the Lightning step signature).

        Returns:
            ``(loss, dsc, preds, images)``.
        """
        x = batch['image']
        targets, recons, masks, preds, attn = self(x)

        # keep detached first-sample outputs for logging / plotting
        self.attn = attn[0, ...].detach()
        self.recons = recons[0, ...].detach()
        self.preds = preds[0, ...].detach()

        # DINOSAUR objective: MSE between each reconstructed patch feature and the
        # corresponding frozen target (torch.cdist compared every patch with every other patch)
        if recons.shape != targets.shape:
            raise ValueError(f"reconstruction shape {tuple(recons.shape)} != target shape {tuple(targets.shape)}")
        loss = F.mse_loss(recons, targets)
        # placeholder: Dice needs labelmaps, which this reconstruction setup does not read
        dsc = 1

        return loss, dsc, preds, x

    def training_step(self, batch, batch_idx):
        loss, dsc, preds, imgs = self.process_batch(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_dice", dsc, prog_bar=True)

        # remember a few images from the first batch of the epoch for logging
        if self.train_imgs is None:
            self.train_imgs = imgs[:5].cpu()
            self.train_preds = preds[:5].cpu()

        return loss

    def validation_step(self, batch, batch_idx):
        loss, dsc, preds, imgs = self.process_batch(batch, batch_idx)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_dice", dsc, prog_bar=True, sync_dist=True)

        if self.val_imgs is None:
            self.val_imgs = imgs[:5].cpu()
            self.val_preds = preds[:5].cpu()

    def on_test_start(self):
        self.test_probmaps = []
        self.test_predmaps = []

    def test_step(self, batch, batch_idx):
        """Log test loss/Dice and collect the per-batch slot predictions (single definition)."""
        loss, dsc, preds, _ = self.process_batch(batch, batch_idx)
        self.log("test_loss", loss)
        self.log("test_dice", dsc)
        self.test_predmaps.append(preds.detach().cpu())

    def on_train_epoch_end(self):
        if not self.log_images:
            return
        if self.train_imgs is not None:
            self._log_train_images(self.train_imgs, self.train_preds)
        # reset so the next epoch captures fresh images
        self.train_imgs = None
        self.train_preds = None
        self.logged_train_images_this_epoch = False

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # log the scheduled learning rate to W&B without advancing its step counter
        if self.log_images and self.logger is not None:
            lr = self.optimizers().param_groups[0]['lr']
            self.logger.experiment.log({"learning_rate": lr}, commit=False)

    def on_validation_epoch_end(self):
        if not self.log_images:
            return
        if self.val_imgs is not None:
            self._log_val_images(self.val_imgs, self.val_preds)
        self.val_imgs = None
        self.val_preds = None
        self.logged_val_images_this_epoch = False

        self._log_key_images()

    def _log_val_images(self, imgs: torch.Tensor, preds: torch.Tensor):
        grid = make_grid(imgs)
        self.logger.experiment.log({
            "Validation Images": [
                wandb.Image(grid, caption="Original Images"),
            ],
        })

    def _log_train_images(self, imgs: torch.Tensor, preds: torch.Tensor):
        grid = make_grid(imgs)
        self.logger.experiment.log({
            "Train Images": [
                wandb.Image(grid, caption="Original Images"),
            ],
        })

    def _log_key_images(self):
        """Log the stored first-sample attention maps and masks as one tile per slot."""
        if self.attn is None or self.masks is None:
            return
        h, w = self.GRID_SIZE
        # make_grid expects (N, C, H, W): one single-channel HxW tile per slot
        # (assumes the stored tensors hold h*w values per slot — TODO confirm layout)
        attn_maps = make_grid(self.attn.float().cpu().reshape(-1, 1, h, w))
        mask_maps = make_grid(self.masks.float().cpu().reshape(-1, 1, h, w))

        self.logger.experiment.log({
            "Slot Attention Maps": [
                wandb.Image(attn_maps, caption="Attention Maps"),
            ],
        })
        self.logger.experiment.log({
            "Masks": [
                wandb.Image(mask_maps, caption="Alpha Masks"),
            ],
        })

In [30]:
# First model is the frozen target encoder, second is fine-tuned.
oss = DINOSAUR(
    saved_model,
    saved_model2,
    num_slots=8,
    num_iterations=3,
    num_classes=4,
    slot_dim=256,
    task='recon',
    learning_rate=4e-4,
    temperature=1,
    log_images=False,
    lr_warmup=True,
    probabilistic_slots=False,
)

In [31]:
# Preprocessed 224x224 CheXpert. Alternative used earlier:
# JSRTDataModule(data_dir='./data/JSRT/', batch_size=32, augmentation=True)
CHEXPERT_DIR = '/vol/biodata/data/chest_xray/CheXpert-v1.0/preproc_224x224/'
data = CheXpertDataModule(
    data_dir=CHEXPERT_DIR,
    batch_size=32,
    cache=False,
)

Loading Data: 100%|██████████| 96609/96609 [00:00<00:00, 1883669.81it/s]
Loading Data: 100%|██████████| 5085/5085 [00:00<00:00, 659738.80it/s]
Loading Data: 100%|██████████| 25424/25424 [00:00<00:00, 793824.14it/s]


In [32]:
# Turn on W&B image/attention logging for training. DINOSAUR.__init__ only calls
# save_hyperparameters() when log_images is True at construction, and the model was
# built with log_images=False — presumably to skip that; confirm intent.
oss.log_images = True

In [33]:
from pytorch_lightning.loggers import WandbLogger
from pathlib import Path
wandb_logger = WandbLogger(save_dir='./runs/lightning_logs/dinosaur_recons/', project='dinosaur_recons')
output_dir = Path(f"dinosaur_recons/run_{wandb_logger.experiment.id}")  # type: ignore
print(f"Saving to {output_dir.absolute()}")

# set before building the Trainer so fp32 matmuls outside autocast may use reduced precision
torch.set_float32_matmul_precision('medium')

# Checkpoints actually go to output_dir (previously it was only printed); keeping the
# best val_loss checkpoint also makes `trainer.checkpoint_callback.best_model_path` meaningful.
checkpoint_callback = ModelCheckpoint(dirpath=output_dir, monitor="val_loss", mode="min")

trainer = Trainer(
    max_steps=500000,
    precision='16-mixed',
    accelerator='auto',
    devices=[0],
    val_check_interval=0.25,     # validate four times per training epoch
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model=oss, datamodule=data)

# Evaluate the best checkpoint afterwards:
# trainer.validate(model=oss, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)
# trainer.test(model=oss, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Saving to/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/dinosaur_recons/run_xg1xvjgr



  | Name                   | Type              | Params
-------------------------------------------------------------
0 | frozen_encoder         | VisionTransformer | 111 M 
1 | trainable_encoder      | VisionTransformer | 111 M 
2 | encoder_pos_embeddings | SoftPositionEmbed | 1.3 K 
3 | slot_attention         | SlotAttention     | 724 K 
4 | mlp_decoder            | MlpDecoder        | 10.5 M
5 | mlp                    | Sequential        | 262 K 
6 | decoder_pos_embeddings | SoftPositionEmbed | 1.3 K 
7 | embedding_norm         | LayerNorm         | 1.5 K 
8 | embedding_norm_decoder | LayerNorm         | 512   
-------------------------------------------------------------
122 M     Trainable params
111 M     Non-trainable params
234 M     Total params
936.164   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


Validation: |          | 0/? [00:00<?, ?it/s]

wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr not found during createRunFiles (<Response [404]>)
wandb: ERROR Error while calling W&B API: run dinosaur_recons/xg1xvjgr no

BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7d97d4139900>> (for post_run_cell), with arguments args (<ExecutionResult object at 7d9834d75570, execution_count=33 error_before_exec=None error_in_exec=[Errno 32] Broken pipe info=<ExecutionInfo object at 7d9834d75b70, raw_cell="from pytorch_lightning.loggers import WandbLogger
.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bdeepmedic3.doc.ic.ac.uk/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/experiment_dinosaur.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
def plot_slot_prediction(image, labelmap, slot_masks, slot_preds, grid=(14, 14)):
    """Plot an image, its optional labelmap, the per-patch slot argmax and the max mask value.

    ``slot_masks`` / ``slot_preds`` are one sample's ``masks`` / ``preds`` from
    ``DINOSAUR.forward``; both are reshaped onto the ``grid`` patch layout
    (assumes 14x14 patches for 224x224 inputs and slots on dim 0 of ``slot_masks``).
    """
    panels = [('image', image, matplotlib.cm.gray)]
    if labelmap is not None:
        panels.append(('labelmap', labelmap, matplotlib.cm.gray))
    panels.append(('prediction', slot_preds.reshape(grid), matplotlib.cm.gray))
    panels.append(('probability map', slot_masks.max(dim=0).values.reshape(grid), 'plasma'))

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 15), squeeze=False)
    for ax, (title, img, cmap) in zip(axes[0], panels):
        ax.imshow(img, cmap=cmap)
        ax.axis('off')
        ax.set_title(title)
    return fig


NUM_IMAGES = 2
test_batch = next(iter(data.test_dataloader()))
images = test_batch['image'][:NUM_IMAGES]
# labelmaps only exist for segmentation datasets (e.g. JSRT); CheXpert recon batches may lack them
labelmaps = test_batch['labelmap'][:NUM_IMAGES] if 'labelmap' in test_batch else None

# process_batch() returns only (loss, dsc, preds, images), so call forward directly to get masks
model_device = next(oss.parameters()).device
with torch.no_grad():
    _, _, slot_masks, slot_preds, slot_attn = oss(images.to(model_device))
# mirror what process_batch stores so the attention-map cell shows this batch
oss.attn = slot_attn[0, ...].detach()

for image_num in range(images.shape[0]):
    plot_slot_prediction(
        images[image_num].squeeze(),
        None if labelmaps is None else labelmaps[image_num].squeeze(),
        slot_masks[image_num].float().cpu(),
        slot_preds[image_num].cpu(),
    )

In [None]:
# plot attention matrix
# Plot the slot-attention maps stored by the last process_batch()/forward call.
# `oss.attn` already holds a single sample (process_batch keeps attn[0, ...]), so it must
# not be indexed by image number again — that selected one slot's flat vector.
assert oss.attn is not None, "run the model through process_batch() first"
# one 14x14 map per slot — assumes oss.attn is (num_slots, 196); TODO confirm layout
slot_attn_maps = oss.attn.detach().float().cpu().reshape(-1, 14, 14).numpy()

fig, axes = plt.subplots(1, slot_attn_maps.shape[0], figsize=(15, 15), squeeze=False)
for slot, (ax, slot_map) in enumerate(zip(axes[0], slot_attn_maps)):
    ax.imshow(slot_map, cmap=matplotlib.cm.gray)
    ax.axis('off')
    ax.set_title(f'slot {slot}')

In [None]:
# calculate entropy of predictions
# Repeat the slot assignment several times on one test batch to estimate
# per-patch prediction entropy in the next cell.
TRIALS = 5
test_images = next(iter(data.test_dataloader()))['image']

trial_preds = []
oss.to(cuda_device)
try:
    images_gpu = test_images.to(cuda_device)
    with torch.no_grad():
        for _ in range(TRIALS):
            # process_batch requires batch_idx and returns (loss, dsc, preds, images)
            _, _, preds, _ = oss.process_batch({'image': images_gpu}, 0)
            trial_preds.append(preds.cpu())
finally:
    oss.to('cpu')  # release the GPU even if a trial fails

all_preds = torch.stack(trial_preds, dim=1)  # (batch, trials, ...)
# per-patch predictions onto the 14x14 patch grid — assumes 196 patches per image
all_preds = all_preds.reshape(all_preds.shape[0], TRIALS, 14, 14)
print(all_preds.shape)

In [None]:
def calculate_class_entropy(predictions):
    """
    Calculate the entropy (in bits) of repeated class predictions at every location.

    Args:
    predictions: Integer tensor of shape (N, *spatial), e.g. (N, H, W) or (N, L),
                 holding N repeated predictions of class indices starting at 0.

    Returns:
    entropy: Float tensor of shape ``spatial`` (e.g. (H, W)). It is 0 where all N
             predictions agree and log2(k) where they split evenly over k classes.

    Raises:
    ValueError: if ``predictions`` is a scalar or contains no elements.
    """
    import math

    if predictions.dim() < 1 or predictions.numel() == 0:
        raise ValueError("predictions must be a non-empty tensor of shape (N, *spatial)")

    num_classes = int(predictions.max().item()) + 1  # assumes class indices start from 0

    # one-hot over a trailing class axis, then average over the N predictions:
    # empirical class frequencies of shape (*spatial, num_classes)
    probabilities = torch.nn.functional.one_hot(predictions.long(), num_classes).float().mean(dim=0)

    # xlogy(p, p) is exactly 0 at p == 0, so no epsilon clamping (and no bias) is needed
    entropy = -torch.xlogy(probabilities, probabilities).sum(dim=-1) / math.log(2)

    return entropy

# Per-location entropy (bits) of the repeated slot assignments for one test image.
image_num = 7
pixel_entropy = calculate_class_entropy(all_preds[image_num, ...]).cpu()
if pixel_entropy.dim() == 1:
    # flat per-patch entropy -> square patch grid (e.g. 196 -> 14x14)
    side = int(round(pixel_entropy.numel() ** 0.5))
    pixel_entropy = pixel_entropy.reshape(side, side)

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
im = ax.imshow(pixel_entropy, cmap='plasma')
fig.colorbar(im, ax=ax, label='entropy (bits)')
ax.axis('off')
ax.set_title(f'prediction entropy (image {image_num})')