# Imports

In [1]:
import os, torch, random
import SimpleITK
import numpy as np
import matplotlib.pyplot as plt 
from torchvision import transforms, models
from torch.utils.data import DataLoader
import re
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

In [2]:
SEED = 2024

def seed_everything(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

seed_everything(SEED)

## GPU

In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


# Dataset

In [4]:
data_transforms = {
    'train': transforms.Compose([
        # Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] 
        # to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        transforms.ToTensor(), 
        # transforms.RandomResizedCrop(224),
        # transforms.RandomHorizontalFlip(),
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(), # PIL Image or numpy.ndarray (H x W x C)
        # transforms.Resize(256),
        # transforms.CenterCrop(224)
    ]),
}

In [123]:
from OASIS_2D.dataset import OASIS_Dataset

total_dataset = OASIS_Dataset(flag='all', seed=SEED)

batch_size = 8
total_dataloader = DataLoader(
    total_dataset, batch_size=batch_size
)

Total 100, disease 50, healthy 50.


# Extract features 

## Model

In [125]:
# https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
model = models.resnet18(weights='DEFAULT')

# Here, we need to freeze all the network except the final layer. 
# We need to set requires_grad = False to freeze the parameters 
# so that the gradients are not computed in backward().
for param in model.parameters():
    param.requires_grad = False
    
# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model.fc = torch.nn.Linear(num_ftrs, 2)

model.to(device)

## Extractor

In [126]:
from torchvision.models.feature_extraction import create_feature_extractor

return_nodes = {
    "avgpool": "avgpool",
    "fc": "fc"
}
extractor = create_feature_extractor(
    model, return_nodes=return_nodes
)

In [127]:
all_outputs = []
for inputs, _ in total_dataloader:
    inputs = inputs.to(device)
    outputs = extractor(inputs)
    
    # only take the output before the final linear layer
    # squeeze to remove pooled dimension (512, 1, 1) -> (512)
    all_outputs.append(outputs['avgpool'].squeeze())
    
all_outputs = torch.vstack(all_outputs)
all_outputs = all_outputs.detach().cpu().numpy()

## Save

In [21]:
features = {
    'patient_id': total_dataset.patient_ids,
    'day': total_dataset.days,
    'label': total_dataset.labels,
    'feature': all_outputs
}

In [128]:
result_dir = os.path.join('OASIS_2D', 'scratch', model._get_name())

if not os.path.exists(result_dir):
    os.makedirs(result_dir, exist_ok=True)
    
torch.save(features, os.path.join(result_dir, 'features.pt'))

# Train

In [71]:
features = torch.load(os.path.join(result_dir, 'features.pt'))

# Dataset

In [134]:
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class OASIS_TemporalDataset(Dataset):
    def __init__(self, features, train, seed=7):
        self.train = train
        self.seed = seed
        self.selected_patients = self.split(features)
        
        id_day_map = {}
        id_day_feature_map = {}

        for patient_id, day, feature, label in zip(
            features['patient_id'], features['day'], 
            features['feature'], features['label']
        ):
            if patient_id not in self.selected_patients: continue
            
            if patient_id not in id_day_map:
                id_day_map[patient_id] = [day]
            else:    
                id_day_map[patient_id].append(day)
                
            id_day_feature_map[(patient_id, day)] = (feature, label)
            
        self.id_day_map = id_day_map
        self.id_day_feature_map = id_day_feature_map
        self.id_days = list(id_day_feature_map.keys())
        
    def split(self, features):
        patient_ids, labels = [], []
        for i in range(len(features['patient_id'])):
            patient_id = features['patient_id'][i]
        
            if patient_id in patient_ids: continue
            
            patient_ids.append(patient_id)
            # assuming same patient id and day can't be in both healthy and disease dir
            labels.append(features['label'][i])
        
        train_ids, test_ids = train_test_split(
            patient_ids, test_size=0.2, shuffle=True, 
            random_state=self.seed, stratify=labels
        )
        
        if self.train: return train_ids
        else: return test_ids
        
    def __len__(self):
        return len(self.id_days)
    
    def __getitem__(self, idx):
        id, current_day = self.id_days[idx]
        
        # sort the current and previous days in ascending order
        prev_days = sorted([day for day in self.id_day_map[id] if day <= current_day])
        
        # get feature for each previous day
        features = [self.id_day_feature_map[(id, d)][0] for d in prev_days]    
        
        # predict the label for the current day
        label = self.id_day_feature_map[(id, current_day)][1]
        
        return torch.tensor(features[-1]), torch.tensor(label)

In [135]:
train_dataset = OASIS_TemporalDataset(features, train=True, seed=SEED)
test_dataset = OASIS_TemporalDataset(features, train=False, seed=SEED)

In [35]:
dimension = len(features['feature'][0])
print(f'Extracted feature has dimension {dimension}')

Extracted feature has dimension 512


# Model

In [136]:
class LstmModel(torch.nn.Module):
    def __init__(
        self, input_size=512, num_layers=1, 
        hidden_size=64, output_size=2, dropout=0.1
    ):
        super(LstmModel, self).__init__()
        # input shape is (batch, seq_len, features)
        self.lstm = torch.nn.LSTM(
            input_size=input_size, 
            hidden_size= hidden_size, 
            num_layers=num_layers,
            batch_first=True
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.fc = torch.nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        # The output of nn.LSTM() is a tuple. The first 
        # element is the generated hidden states, 
        # one for each time step of the input. The 
        # second element is the LSTM cell’s memory 
        # and hidden states, which is not used here.
        # output (batch x hidden_size), (hc, cn)
        x, _ = self.lstm(x) 
        
        # The output of hidden states is further processed by a 
        # fully-connected layer to produce a single regression result. 
        # Since the output from LSTM is one per each input time step, 
        # you can chooce to pick only the last timestep’s output
        # x = self.fc(x[:, -1, :])
        x = self.fc(self.dropout(x))
        
        return x

In [90]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.callbacks import EarlyStopping
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score

In [141]:
model = LstmModel()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()
early_stopping = EarlyStopping(patience=10, verbose=False, save_dir=result_dir)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)

scheduler = ReduceLROnPlateau(
    optimizer, 'min', 
    patience=3, min_lr=1e-6, verbose=True
)

n_epochs = 100

for epoch in range(n_epochs):
    total_loss = []
    total_probs, total_preds, total_labels = [], [], []
    model.train()
    # training
    for x_train, y_train in train_dataloader:
        x_train, y_train = x_train.to(device), y_train.to(device)
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        
        total_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        
        probs = torch.sigmoid(outputs)
        preds = torch.argmax(probs, dim=1)

        # comvert to numpy
        probs = probs.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()

        # append to list
        total_labels.extend(y_train.detach().cpu().numpy())
        total_probs.extend(probs[:, 1])
        total_preds.extend(preds)
        
    train_loss = np.mean(total_loss)
    
    # validation
    model.eval()
    total_loss = []
    
    with torch.no_grad():
        for x_test, y_test in test_dataloader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            outputs = model(x_test)
            loss = criterion(outputs, y_test)
            total_loss.append(loss.item())
        
    val_loss = np.mean(total_loss)    
        
    if epoch % 10 == 0:
        # calculate metrics
        acc = accuracy_score(total_labels, total_preds)
        auc = roc_auc_score(total_labels, total_probs)
        f1 = f1_score(total_labels, total_preds)
        
        # print metrics
        print(
            f'Epoch {epoch+1}, loss {train_loss:0.4g} val_loss {val_loss:0.4g} f1 {f1:0.4g} acc {acc:0.4g} auc {auc:0.4g}'
        )
        
    # early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    scheduler.step(val_loss)

Epoch 1, loss 0.7253 val_loss 0.7938 f1 0.5135 acc 0.5325 auc 0.5162
EarlyStopping counter: 1 out of 10
EarlyStopping counter: 2 out of 10
EarlyStopping counter: 3 out of 10
EarlyStopping counter: 4 out of 10
Epoch 00007: reducing learning rate of group 0 to 1.0000e-04.
EarlyStopping counter: 5 out of 10
EarlyStopping counter: 6 out of 10
EarlyStopping counter: 7 out of 10
Epoch 11, loss 0.6318 val_loss 0.7593 f1 0.6769 acc 0.7273 auc 0.8014
EarlyStopping counter: 8 out of 10
Epoch 00011: reducing learning rate of group 0 to 1.0000e-05.
EarlyStopping counter: 9 out of 10
EarlyStopping counter: 10 out of 10
Early stopping
