# Imports

In [1]:
import os, torch, random
import SimpleITK, re
import numpy as np
import matplotlib.pyplot as plt 
from torchvision import transforms, models
from torch.utils.data import DataLoader
import re
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.callbacks import EarlyStopping
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

In [20]:
# Backbone to use for feature extraction: 'ViT' or 'ResNet' (see the Model cell).
model_name = "ViT"

In [2]:
SEED = 2024


def seed_everything(seed):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch.

    Args:
        seed: integer seed applied to all three generators.
    """
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(seed)


seed_everything(SEED)

## Device (CUDA, MPS, or CPU fallback)

In [3]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


# Dataset

In [4]:
# NOTE(review): `data_transforms` is not referenced by any of the cells shown
# here; only `vit_transforms` is used (when model_name == 'ViT').
data_transforms = {
    'train': transforms.Compose([
        # Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]
        # to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        transforms.ToTensor(),
        # transforms.RandomResizedCrop(224),
        # transforms.RandomHorizontalFlip(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),  # PIL Image or numpy.ndarray (H x W x C)
        # transforms.Resize(256),
        # transforms.CenterCrop(224)
    ]),
}

# Preprocessing for torchvision's ViT-B/16, whose forward pass only accepts
# 224 x 224 images.
vit_transforms = transforms.Compose([
    # PIL Image / numpy.ndarray (H x W x C) in [0, 255]
    # -> torch.FloatTensor (C x H x W) in [0.0, 1.0]
    transforms.ToTensor(),
    # Use an explicit (H, W) size. A single int (Resize(224)) only rescales the
    # *shorter* edge, so a non-square slice would stay non-square and be rejected
    # by the ViT's fixed-size input check. For square inputs this is identical.
    transforms.Resize((224, 224)),
    # NOTE(review): the DEFAULT pretrained weights expect ImageNet-normalized
    # 3-channel input, but no Normalize(...) is applied here. Confirm whether
    # OASIS_Dataset(vit=True) handles channel replication / normalization.
])

In [21]:
# Only the ViT backbone needs the resize pipeline; other models take the
# dataset's raw output.
transform = vit_transforms if model_name == 'ViT' else None

In [23]:
from OASIS_2D.dataset import OASIS_Dataset

# Presumably every subject (flag='all'). `transform` is the pipeline chosen in
# the previous cell.
is_vit = model_name == 'ViT'
total_dataset = OASIS_Dataset(flag='all', seed=SEED, transform=transform, vit=is_vit)

batch_size = 8
# shuffle=False is the DataLoader default; it is spelled out because the saved
# feature rows must stay in dataset order to line up with
# total_dataset.patient_ids / days / labels.
total_dataloader = DataLoader(total_dataset, batch_size=batch_size, shuffle=False)

Total 100, disease 50, healthy 50.


# Extract features 

## Model

In [6]:
# Build a pretrained backbone and replace its classification head with a
# 2-output Linear layer, following the PyTorch transfer-learning tutorial:
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
if model_name == 'ViT':
    model = models.vision_transformer.vit_b_16(weights='DEFAULT')
    num_ftrs = model.heads.head.in_features
    model.heads.head = torch.nn.Linear(num_ftrs, 2)
    # NOTE: total_dataset was already constructed with `transform`; this
    # reassignment does not change it.
    transform = vit_transforms
elif model_name == 'ResNet':
    model = models.resnet18(weights='DEFAULT')
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, 2)
    transform = None
else:
    raise NotImplementedError(f'Model {model_name} not implemented.')

# The model is used purely as a frozen feature extractor, so gradients are
# disabled for *every* parameter, including the freshly created head.
model.requires_grad_(False)

model = model.eval().to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

## Extractor

In [10]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

In [15]:
# a list of node names from tracing the model in train mode, 
# and another from tracing the model in eval mode.
# get_graph_node_names returns two lists: nodes traced in train mode and nodes
# traced in eval mode. Index 1 selects the eval-mode list.
node_lists = get_graph_node_names(model)
nodes = node_lists[1]

# Return the last two graph nodes under their own names. For the ViT run these
# were 'getitem_5' (the input to the head) and 'heads.head' (the head output).
return_nodes = dict(zip(nodes[-2:], nodes[-2:]))
extractor = create_feature_extractor(model, return_nodes=return_nodes)
print(return_nodes)

{'getitem_5': 'getitem_5', 'heads.head': 'heads.head'}

In [27]:
# Run the frozen extractor over every batch, in dataset order, and keep the
# activation of nodes[-2], i.e. the representation fed into the final Linear
# head. The result is a 2-D numpy array with one row per sample.
all_outputs = []
with torch.no_grad():  # pure inference: no autograd graph is needed
    for inputs, _ in total_dataloader:
        inputs = inputs.to(device)
        outputs = extractor(inputs)

        # flatten(start_dim=1) collapses any trailing dims (e.g. a pooled CNN's
        # (B, C, 1, 1)) to (B, C) while always keeping the batch dim. squeeze()
        # would also drop the batch dim for a batch of size 1, and would turn a
        # (B, 1) output into (B,).
        # Move each batch to the CPU so the GPU does not accumulate all features.
        all_outputs.append(outputs[nodes[-2]].flatten(start_dim=1).cpu())

all_outputs = torch.cat(all_outputs, dim=0).numpy()

In [29]:
# Sanity check: (number of samples, feature dimension), then a preview.
print(f"{all_outputs.shape}")
all_outputs

array([[-0.42797986,  0.3932665 ,  0.1960909 , ...,  0.52814484,
         0.72506505, -2.0537581 ],
       [-0.3423489 ,  0.351157  ,  0.11749738, ...,  0.42512724,
         0.79731154, -2.2009861 ],
       [ 0.13467292,  0.56735533, -0.06937069, ...,  0.6225311 ,
         1.1357019 , -2.8951213 ],
       ...,
       [ 0.13928331,  0.752598  , -0.1792116 , ...,  0.22843729,
         0.91875046, -2.4359908 ],
       [ 0.16787422,  0.24648038,  0.1279275 , ...,  0.42462593,
         0.8471592 , -2.2496784 ],
       [ 0.11609305,  0.39360955, -0.14874858, ...,  0.2628085 ,
         0.8786392 , -2.0765867 ]], dtype=float32)

## Save

In [30]:
# Bundle the per-sample metadata with the extracted features. The DataLoader
# does not shuffle, so row i of all_outputs belongs to patient_ids[i],
# days[i] and labels[i]. Fail loudly if the counts ever diverge.
assert len(all_outputs) == len(total_dataset), (
    f'Expected {len(total_dataset)} feature rows, got {len(all_outputs)}'
)

features = {
    'patient_id': total_dataset.patient_ids,
    'day': total_dataset.days,
    'label': total_dataset.labels,
    'feature': all_outputs,
}

In [31]:
# Save to OASIS_2D/results/<model_name>/features.pt.
result_dir = os.path.join('OASIS_2D', 'results', model_name)

# exist_ok=True already tolerates an existing directory, so the separate
# os.path.exists() check was redundant (and racy).
os.makedirs(result_dir, exist_ok=True)

# NOTE: 'feature' is a numpy array, so the file is a general pickle. Recent
# PyTorch versions default torch.load to weights_only=True, which rejects it.
# Load it with torch.load(path, weights_only=False) from a trusted source only.
torch.save(features, os.path.join(result_dir, 'features.pt'))