In [1]:
import sys
sys.path.append("../src")

import os
from typing import Dict
from pathlib import Path

import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

from utils import swin_transformer

In [2]:
# Read the label table from the data folder and preview its first rows.
label_dir = os.path.join("..", "data", "labels.csv")

labels = pd.read_csv(filepath_or_buffer=label_dir)
labels.head(n=5)

Unnamed: 0,id,grade
0,16425,1
1,16421,1
2,16223,1
3,16089,1
4,16026,1


In [3]:
# Per-column count of missing values in the label table.
labels.isnull().sum()

id       0
grade    0
dtype: int64

In [4]:
# Patient ids in CSV row order, forced to str so they always match the
# string keys used for lookup (file stems), even if pandas parsed the
# column as integers.
ids = labels["id"].astype(str).tolist()

# Binary target: grade 1 -> 0, every other grade -> 1.
# Compare on the stripped string form: an integer 1 never equals the string
# "1", so a numerically-parsed column would otherwise be labelled 1 throughout.
grades = labels["grade"].map(lambda x: 0 if str(x).strip() == "1" else 1).tolist()

In [5]:
# Map each patient id to its binary grade (a repeated id keeps its last grade).
# NOTE: this rebinds `labels` from the DataFrame read above to a plain dict, so
# the earlier cells that index `labels["id"]` / `labels["grade"]` cannot be
# re-run after this one without reloading the CSV.
labels = dict(zip(ids, grades))

labels

{'16425': 0,
 '16421': 0,
 '16223': 0,
 '16089': 0,
 '16026': 0,
 '15779': 0,
 '15513': 1,
 '15463A': 1,
 '15463B': 1,
 '15286': 1,
 '15060': 1,
 '15050': 0,
 '14917A': 1,
 '14917B': 1,
 '14917C': 1,
 '14732A': 1,
 '14732B': 1,
 '14703': 1,
 '14696A': 0,
 '14696B': 1,
 '14572A': 1,
 '14404': 0,
 '14148': 1,
 '14120': 0,
 '14080': 0,
 '14077': 0,
 '13982': 0,
 '13838': 0,
 '13820': 0,
 '13791': 1,
 '13742': 0,
 '13675': 0,
 '13663': 0,
 '13645': 0,
 '13554': 0,
 '13540': 0,
 '13475': 1,
 '13424': 0,
 '13353': 0,
 '13284': 0,
 '13267': 0,
 '13193': 1,
 '13191': 1,
 '13179': 0,
 '13153': 1,
 '13119': 0,
 '13055': 1,
 '13054A': 0,
 '13054B': 1,
 '12925': 0,
 '12904': 1,
 '12845': 1,
 '12801': 1,
 '12768': 0,
 '12691': 0,
 '12662': 1,
 '12615': 0,
 '12570': 0,
 '12529': 0,
 '12524': 1,
 '12447': 0,
 '12424': 0,
 '12404': 0,
 '12399': 0,
 '12380': 0,
 '12327': 1,
 '12220': 0,
 '12186': 0,
 '12180': 0,
 '12169': 0,
 '12145': 0,
 '12120': 0,
 '12063': 0,
 '12015': 0,
 '12010A': 1,
 '11987': 0,

In [6]:
class WSIDataset(Dataset):

    """
    Dataset pairing each stored embedding array with a binary grade label.

    Every entry of ``data_dir`` is treated as one sample. The entry's file
    stem (file name without extension) is the patient id used to look up
    its label in the labels CSV.
    """

    def __init__(
        self, 
        data_dir: str, 
        label_dir: str,
        ):
        """
        Args:
            data_dir: Directory containing one ``np.load``-readable file per sample.
            label_dir: Path to the labels CSV; must have ``id`` and ``grade`` columns.

        Raises:
            AssertionError: If an entry of ``data_dir`` has no matching id in the CSV.
        """

        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sort them so that
        # index -> sample is stable across runs and machines.
        self.patient_ids = sorted(os.listdir(data_dir))
        self.labels = self.generate_labels(label_dir)

        assert all(Path(i).stem in self.labels for i in self.patient_ids), "All patient ids must have a label"

    
    def generate_labels(self, label_dir: str) -> Dict[str, int]:

        """
        Creates a dictionary containing the patient ids as keys
        and the binarised Meningioma grade as the values
        (grade 1 -> 0, any other grade -> 1).

        Args:
            label_dir: Path to the labels CSV.

        Returns:
            Mapping from patient id (str) to binary label (int).
        """

        # Read both columns as text. With dtype inference, all-numeric ids
        # become ints (and never match the string file stems) and numeric
        # grades never compare equal to "1".
        labels = pd.read_csv(label_dir, dtype={"id": str, "grade": str})
        ids = labels["id"].tolist()
        grades = labels["grade"].map(lambda x: 0 if str(x).strip() == "1" else 1).tolist()

        return dict(zip(ids, grades))
    
    def __len__(self):
        """Number of entries found in ``data_dir``."""
        return len(self.patient_ids)
    
    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            Tuple ``(embedding, label)`` where ``embedding`` is the loaded array
            as a tensor with its last axis moved to the front, and ``label`` is
            the binary grade (int).
        """
        patient_id = self.patient_ids[idx]
        label = self.labels[Path(patient_id).stem]

        embedding_path = os.path.join(self.data_dir, patient_id)
        # permute(2, 0, 1) requires a 3-D array and moves its last axis first
        # (presumably (H, W, C) -> (C, H, W); confirm against the saved arrays).
        embedding = torch.tensor(np.load(embedding_path)).permute(2, 0, 1)

        return embedding, label


In [7]:
# Locations of the training arrays and of the label table.
data_dir = os.path.join("..", "data", "UNI", "trial-1", "train")
label_dir = os.path.join("..", "data", "labels.csv")

# One sample per batch, drawn in shuffled order.
wsi_dataset = WSIDataset(data_dir=data_dir, label_dir=label_dir)
wsi_loader = DataLoader(dataset=wsi_dataset, shuffle=True, batch_size=1)

In [8]:
# Pull a single (embedding, label) batch from a fresh iterator over the loader.
# The loader shuffles, so which sample comes back can differ between runs.
embedding, label = next(iter(wsi_loader))

In [9]:
# Model configuration: Swin version, size variant, and number of output classes.
version, variant, num_classes = "v1", "tiny", 2

model = swin_transformer(version, variant, num_classes)

In [10]:
# Display the module tree to inspect the layer configuration.
model

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(1024, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=384, out_features=96, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
    

In [13]:
# Smoke test: run the sampled batch through the model and show the raw outputs.
# Gradient tracking is left on here, so the result carries a grad_fn.
model(embedding)

tensor([[ 0.0605, -0.2405]], grad_fn=<AddmmBackward0>)

In [12]:
# Shape of the collated batch: the DataLoader's batch dimension followed by
# the axes of the permuted array returned by the dataset.
embedding.shape

torch.Size([1, 1024, 242, 194])