In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from timm.models.vision_transformer import PatchEmbed, Block
from pytorch_lightning.strategies import DDPStrategy

from data import JSRTDataModule

from medmnist.info import INFO
from medmnist.dataset import MedMNIST
from torch.utils.data import DataLoader

In [2]:
def create_patch_sequence(x):
        patch_height, patch_width = 16, 16

        # unfold 2 extracts patches from height dimension while unfold 3 extracts patches from width dimension
        patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
        # Reshape patches to the desired format
        patches = patches.contiguous().view(x.size(0), x.size(1), -1, patch_height, patch_width)
        
        # Permute to have shape (batch_size, num_patches, channels, patch_height, patch_width)
        patches = patches.permute(0, 2, 1, 3, 4)

        return patches

def create_masks(patches, mask_ratio=0.75):
        """mask specified proportion of patches"""
        device = patches.device
        patches = patches.view(patches.shape[0], patches.shape[1], -1)
        batch_size, num_patches, _, = patches.shape
        num_masked_tokens = int(num_patches * mask_ratio)

        # take num_masks random indices
        random_values = torch.rand(batch_size, num_patches)

        indices = torch.argsort(random_values, dim=1)
        
        mask_indices = indices[:, num_patches - num_masked_tokens:]

        mask = torch.zeros(batch_size, num_patches, dtype=torch.bool, device=device)
        mask.scatter_(1, mask_indices, True)

        masked_patches = patches.clone()
        masked_patches[mask] = 0
        
        ids_reverse = torch.argsort(indices, dim=1)

        return masked_patches, ids_reverse, num_masked_tokens, mask

def reverse_patch_sequence(x):

        # fold patches back to image
        c = 1
        p, q = 16, 16
        h, w = 224, 224
        n = x.shape[0]
        x = x.reshape(shape=(n, h // p, w // q, p, q, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(n, c, h, w))

        return imgs

In [3]:
# class ChestMNISTDataset(MedMNIST):
#     def __init__(self, split = 'train'):
#         ''' Dataset class for PneumoniaMNIST.
#         The provided init function will automatically download the necessary
#         files at the first class initialistion.

#         :param split: 'train', 'val' or 'test', select subset

#         '''
#         self.flag = "chestmnist"
#         self.size = 28
#         self.size_flag = ""
#         self.root = './data/chestmnist/'
#         self.info = INFO[self.flag]
#         self.download()

#         npz_file = np.load(os.path.join(self.root, "chestmnist.npz"))

#         self.split = split

#         # Load all the images
#         assert self.split in ['train','val','test']

#         self.imgs = npz_file[f'{self.split}_images']
#         self.labels = npz_file[f'{self.split}_labels']

#     def __len__(self):
#         return self.imgs.shape[0]

#     def __getitem__(self, index):
#         # TASK: Fill in the blanks such that you return two tensors
#         # of shape [1, 28, 28], img_view1 and img_view2, representing two augmented view of the images.

#         image = torch.tensor(self.imgs[index]).unsqueeze(0)

#         return image

In [4]:
# class ChestMNISTDataModule(LightningDataModule):
#     def __init__(self, batch_size: int = 8):
#         super().__init__()
#         self.batch_size = batch_size
#         self.train_set = ChestMNISTDataset(split='train')
#         self.val_set = ChestMNISTDataset(split='val')
#         self.test_set = ChestMNISTDataset(split='test')

#     def train_dataloader(self):
#         return DataLoader(dataset=self.train_set, batch_size=self.batch_size, shuffle=True)

#     def val_dataloader(self):
#         return DataLoader(dataset=self.val_set, batch_size=self.batch_size, shuffle=False)

#     def test_dataloader(self):
#         return DataLoader(dataset=self.test_set, batch_size=self.batch_size, shuffle=False)

In [5]:
# data_module = ChestMNISTDataModule(batch_size=16)

# train_dataloader = data_module.train_dataloader()

In [6]:
# data = JSRTDataModule(data_dir='./data/JSRT/', batch_size=2)
# batch = next(iter(train_dataloader))

# sample = batch[5]
# print(sample.shape)
# patches = create_patch_sequence(sample)
# patches = torch.randn((2, 196, 20))

# masked_patches, mask_indices, num_masked_tokens, mask = create_masks(patches, 0.75)

In [7]:
# view masked image
# import matplotlib.pyplot as plt
# import matplotlib

# image_num = 0
# image = sample[image_num].squeeze()

# f, ax = plt.subplots(1, 2, figsize=(10, 10))

# ax[0].imshow(image, cmap=matplotlib.cm.gray)
# ax[0].axis('off')
# ax[0].set_title('image')

# ax[1].imshow(reversed_pathced_image.squeeze(), cmap=matplotlib.cm.gray)
# ax[1].axis('off')
# ax[1].set_title('masked image')

In [8]:
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size[0], grid_size[1])
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

In [9]:
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        embed_dim=1024,
        num_channels=3,
        num_heads=16,
        depth=24,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        norm_layer = nn.LayerNorm,
        mlp_ratio=4.0,
        patch_size=16,
        norm_pix_loss=False,
        dropout=0.0,
    ):
        """Vision Transformer.

        Args:
            embed_dim: Dimensionality of the input feature vectors to the Transformer
            hidden_dim: Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels: Number of channels of the input (3 for RGB)
            num_heads: Number of heads to use in the Multi-Head Attention block
            num_layers: Number of layers to use in the Transformer
            num_classes: Number of classes to predict
            patch_size: Number of pixels that the patches have per dimension
            num_patches: Maximum number of patches an image can have
            dropout: Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = num_channels

        # -------ENCODER PART------------------------------------------------------
        self.patch_embed = PatchEmbed(img_size, patch_size, num_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.embed_dim = embed_dim
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

        self.norm = norm_layer(embed_dim)

        # -------DECODER PART------------------------------------------------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(decoder_depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * num_channels, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim=self.pos_embed.shape[-1],
            grid_size=self.patch_embed.grid_size,
            cls_token=True,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            grid_size=self.patch_embed.grid_size,
            cls_token=True,
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def create_patch_sequence(self, x):
        patch_height, patch_width = self.patch_embed.patch_size[0], self.patch_embed.patch_size[1]

        # unfold 2 extracts patches from height dimension while unfold 3 extracts patches from width dimension
        patches = x.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
        # ensure patches are contiguous
        patches = patches.contiguous().view(-1, x.size(0), x.size(1), patch_height, patch_width)
        # permute to have shape (batch_size, num_patches, channels, patch_size, patch_size)
        patches = patches.permute(1, 0, 2, 3, 4)

        # create sequence of pixels
        n, c, h, w = ((x.size(0), x.size(1), x.size(2) // patch_height, x.size(3) // patch_width))

        x = x.reshape(shape=(n, c, h, patch_height, w, patch_width))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(n, h * w, patch_height* patch_height * c))

        return x

    def reverse_patch_sequence(self, x):
        # fold patches back to image
        c = self.in_channels
        p, q = self.patch_embed.patch_size[0], self.patch_embed.patch_size[1]
        h, w = self.patch_embed.img_size[0], self.patch_embed.img_size[1]
        n = x.shape[0] # batch size
        x = x.reshape(shape=(n, h // p, w // q, p, q, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(n, c, h, w))

        return imgs

    def create_masks(self, patches, mask_ratio=0.75):
        """mask specified proportion of patches"""
        device = patches.device

        batch_size, num_patches, _, = patches.shape
        num_masked_tokens = int(num_patches * mask_ratio)

        # Generate random values and sort to get mask indices
        random_values = torch.rand(batch_size, num_patches, device=device)
        indices = torch.argsort(random_values, dim=1)

        mask_indices = indices[:, :num_masked_tokens]

        # Create mask and scatter values to mask indices
        mask = torch.zeros(batch_size, num_patches, dtype=torch.bool, device=device)
        mask.scatter_(1, mask_indices, True)

        # Clone patches and apply mask
        masked_patches = patches.clone()
        masked_patches[mask] = 0

        # Get reverse indices for potential reconstruction
        ids_reverse = torch.argsort(indices, dim=1)

        return masked_patches, mask, ids_reverse, num_masked_tokens

    def encoder(self, x, mask_ratio=0.75):
        # Preprocess input

        x = self.patch_embed(x)

        # add positional embedding
        x = x + self.pos_embed[:, 1:, :]

        # perform random masking
        masked_image, mask, mask_indices, num_masked_tokens = self.create_masks(x, mask_ratio)

        # Add CLS token and positional encoding
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(masked_image.shape[0], -1, -1)
  
        x = torch.cat((cls_tokens, masked_image), dim=1)

        # Apply Transforrmer
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, mask_indices, num_masked_tokens
    
    def decoder(self, x, mask_indices, num_masked_tokens):
        device = x.device

        x = self.decoder_embed(x)
        
        # # add mask tokens    
        mask_tokens = self.mask_token.repeat(x.shape[0], num_masked_tokens, 1)

        # x[:, 1:, :] removes the first token (CLS token) from x
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1).to(device)  # no cls token
        mask_indices = mask_indices.to(device)
        x_ = torch.gather(x_, dim=1, index=mask_indices.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
    
    def loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """

        target = self.create_patch_sequence(imgs)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches

        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent,  mask, mask_indices, num_masked_tokens = self.encoder(imgs, mask_ratio)
        pred = self.decoder(latent, mask_indices, num_masked_tokens)  # [N, L, p*p*1]
        loss = self.loss(imgs, pred, mask)
        return loss, pred, mask, latent

In [10]:
class ViTAE(LightningModule):
    def __init__(self, model_kwargs, learning_rate: float = 0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()

        # define transformer block
        self.model = VisionTransformer(**model_kwargs)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, pred, mask, latent = self.model.forward(batch['image'])

        self.log("train_loss", loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, pred, mask, latent = self.model.forward(batch['image'])
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        pass

In [11]:
#seed_everything(42, workers=True)

# data = JSRTDataModule(data_dir='./data/JSRT/', batch_size=32)

# model = ViTAE(
#     model_kwargs={
#         'img_size': 224,
#         'embed_dim': 768,
#         'num_channels': 1,
#         'num_heads': 12,
#         'depth': 14,
#         'decoder_embed_dim': 512,
#         'decoder_depth': 8,
#         'decoder_num_heads': 16,
#         'norm_layer': nn.LayerNorm,
#         'mlp_ratio': 4.0,
#         'patch_size': 16,
#         'norm_pix_loss': False,
#         'dropout': 0.0,
#     },
#     learning_rate=1e-3,
# )

# trainer = Trainer(
#     max_epochs=1,
#     accelerator='auto',
#     devices=[0, 1],
#     strategy='ddp_notebook',
#     log_every_n_steps=4,
#     check_val_every_n_epoch=50,
#     #save_top_k=1,
#     logger=TensorBoardLogger(save_dir='./lightning_logs/autoencoder/', name='ViTAE'),
#     #callbacks=[ModelCheckpoint(monitor="val_loss", mode='min'), TQDMProgressBar(refresh_rate=4)],
# )
# trainer.fit(model=model, datamodule=data)

#trainer.validate(model=model, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)

#trainer.test(model=model, datamodule=data, ckpt_path=trainer.checkpoint_callback.best_model_path)

In [12]:
DEVICE = torch.device('cuda:1')

saved_model = ViTAE.load_from_checkpoint('/vol/bitbucket/bc1623/project/semi_supervised_uncertainty/lightning_logs/autoencoder/ViTAE/version_54/checkpoints/epoch=799-step=4800.ckpt',
    model_kwargs={
        'img_size': 224,
        'embed_dim': 1024,
        'num_channels': 1,
        'num_heads': 16,
        'depth': 16,
        'decoder_embed_dim': 512,
        'decoder_depth': 10,
        'decoder_num_heads': 16,
        'norm_layer': nn.LayerNorm,
        'mlp_ratio': 4.0,
        'patch_size': 16,
        'norm_pix_loss': False,
        'dropout': 0.0,
    },
    learning_rate=1e-4,
    map_location = DEVICE
    )

saved_model.eval()

ViTAE(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 1024, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0-15): 16 x Block(
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (drop1): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (fc2)

In [13]:
saved_model.eval()
batch = next(iter(data.test_dataloader()))
with torch.no_grad():
    encoded_embeddings = saved_model.model.encoder(batch['image'].to(DEVICE)) # [batch_size, num_patches + 1, embed_dim]

NameError: name 'data' is not defined

In [None]:
import matplotlib.pyplot as plt
import matplotlib
#from matplotlib.cm import ScalarMappable
image_num = 0
image = batch['image'][image_num].squeeze()

f, ax = plt.subplots(1, 3, figsize=(10, 10))

ax[0].imshow(image, cmap=matplotlib.cm.gray)
ax[0].axis('off')
ax[0].set_title('image')

masked_patches, _, masked_indices, num_masked_tokens = create_masks(create_patch_sequence(batch['image']), 0.75)
masked_patches = reverse_patch_sequence(masked_patches[0, :, :].unsqueeze(0))

ax[1].imshow(masked_patches.squeeze().cpu(), cmap=matplotlib.cm.gray)
ax[1].axis('off')
ax[1].set_title('e.g. masked image')

# reconstructed image
with torch.no_grad():
    latent, mask, mask_indices, num_masked_tokens = saved_model.model.encoder(batch['image'][0].unsqueeze(0).to(DEVICE))
    decoded_image = saved_model.model.decoder(latent, mask_indices, num_masked_tokens)
reconstructed_image = reverse_patch_sequence(decoded_image)
ax[2].imshow(reconstructed_image.squeeze().cpu().detach().numpy(), cmap=matplotlib.cm.gray)
ax[2].axis('off')
ax[2].set_title('reconstructed')

In [28]:
# A list to store the attention weights
attention_weights = []

# Define a hook to capture attention weights
def get_attention_weights(module, input, output):
    attention = module.attn_drop.detach().cpu().numpy()
    attention_weights.append(attention)

In [29]:
data = JSRTDataModule(data_dir='./data/JSRT/', batch_size=2)
batch = next(iter(data.test_dataloader()))
# Register hooks to all attention layers
#for blk in saved_model.model.blocks:
saved_model.model.blocks[1].attn.register_forward_hook(get_attention_weights)

with torch.no_grad():
    latent, mask, mask_indices, num_masked_tokens = saved_model.model.encoder(batch['image'][0].unsqueeze(0).to(DEVICE))
    #decoded_image = saved_model.model.decoder(latent, mask_indices, num_masked_tokens)

True


Loading Data: 100%|██████████| 187/187 [00:00<00:00, 27002.27it/s]
Loading Data: 100%|██████████| 10/10 [00:00<00:00, 20039.68it/s]


Loading Data: 100%|██████████| 50/50 [00:00<00:00, 30539.57it/s]


Attention(
  (qkv): Linear(in_features=1024, out_features=3072, bias=True)
  (q_norm): Identity()
  (k_norm): Identity()
  (attn_drop): Dropout(p=0.0, inplace=False)
  (proj): Linear(in_features=1024, out_features=1024, bias=True)
  (proj_drop): Dropout(p=0.0, inplace=False)
) (tensor([[[ 7.8916e-02,  9.5279e-01,  1.8710e+00,  ...,  6.7404e-01,
           2.0214e+00,  2.1062e+00],
         [-6.6779e-01, -5.2936e-01, -1.5531e+00,  ...,  1.4771e+00,
           7.8221e-01,  1.4951e+00],
         [ 7.4483e-04,  5.9393e-04, -3.6893e-04,  ..., -4.3520e-05,
          -7.5248e-04, -3.4184e-04],
         ...,
         [ 7.4483e-04,  5.9393e-04, -3.6893e-04,  ..., -4.3520e-05,
          -7.5248e-04, -3.4184e-04],
         [ 7.4483e-04,  5.9393e-04, -3.6893e-04,  ..., -4.3520e-05,
          -7.5248e-04, -3.4184e-04],
         [ 7.4483e-04,  5.9393e-04, -3.6893e-04,  ..., -4.3520e-05,
          -7.5248e-04, -3.4184e-04]]], device='cuda:1'),) tensor([[[ 0.2668,  0.0719,  0.0276,  ..., -0.1156,  0.1

AttributeError: 'Dropout' object has no attribute 'softmax'

In [19]:
attention_weights[0].shape

torch.Size([1, 197, 1024])

In [None]:
# plot attention weights
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib

attention = attention_weights[0][0, 0]  # First block, first head
print(attention.shape)
attention = attention.reshape(32, 32)
# Plot the attention map
plt.figure(figsize=(10, 10))
sns.heatmap(attention, cmap='viridis')
plt.title('Attention Map of the First Head in the First Block')
plt.show()


In [32]:
saved_model.model.blocks[0].attn.qkv

Linear(in_features=1024, out_features=3072, bias=True)