In [23]:
from transformers import ResNetConfig, ResNetForImageClassification

config = ResNetConfig(
    num_channels=91,
    num_classes=2,
)

model = ResNetForImageClassification(config)

_config = model.config
print(_config)

ResNetConfig {
  "depths": [
    3,
    4,
    6,
    3
  ],
  "downsample_in_first_stage": false,
  "embedding_size": 64,
  "hidden_act": "relu",
  "hidden_sizes": [
    256,
    512,
    1024,
    2048
  ],
  "layer_type": "bottleneck",
  "model_type": "resnet",
  "num_channels": 91,
  "num_classes": 2,
  "out_features": [
    "stage4"
  ],
  "out_indices": [
    4
  ],
  "stage_names": [
    "stem",
    "stage1",
    "stage2",
    "stage3",
    "stage4"
  ],
  "transformers_version": "4.30.0.dev0"
}



In [11]:
import random

# open files.txt which contains all the file paths
with open('/home/minghui/github/NMSS/vitmae/files.txt', 'r') as f:
    files = f.readlines()
files = [x.strip() for x in files]

# split data into train, val and test set
random.shuffle(files)
train_set = files[:int(len(files)*0.8)]
val_set = files[int(len(files)*0.8):int(len(files)*0.9)]
test_set = files[int(len(files)*0.9):]

In [12]:
import re
import os
import csv
import torch
import numpy as np
from einops import rearrange
from torch.utils.data import Dataset
from monai.data import NibabelReader

class SPRINT_T1w_flat_Dataset:
    def __init__(self, data_dir, filenames, subjects_csv, mode='train', transform=None):
        self.data_dir = data_dir
        # read all nifti files in data_dir
        self.filenames = filenames
        self.transform = transform
                
        # read labels of each subject
        self.labels = {}
        with open(subjects_csv, "r") as fp:
            csv_reader = csv.reader(fp, delimiter=",")
            for row in csv_reader:
                if row[0] == "subject_id":
                    continue
                id = row[0]
                label = int(row[4])
                # self.labels[id] = label
                self.labels[id] = torch.zeros(2)
                self.labels[id][label] = 1

        # count how many label == 1
        progress = 0
        not_progress = 0
        for filename in self.filenames:
            id = re.search(r'subject_(\d{3})-(\d{4})', filename).group(2)
            if self.labels[id][1] == 1:
                progress += 1
            else:
                not_progress += 1
        print(f"Total subjects: {len(self.filenames)}, Progressing: {progress}, Not progressing: {not_progress}")

        # create image reader
        self.image_reader = NibabelReader()

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        # get subject id from filename using regex
        id = re.search(r'subject_(\d{3})-(\d{4})', filename).group(2)
        image = self.image_reader.read(os.path.join(self.data_dir, filename))
        image = image.get_fdata().astype(np.float32)
        image = torch.from_numpy(image)
        # leave only middle 91 channels in dimension 1
        image = image[:, 9:100, :]
        if self.transform:
            image = self.transform(image)
        label = self.labels[id]

        return {'image': image, 'label': label}


In [16]:
from typing import Any
from torch.utils.data import DataLoader
import torchvision.transforms as T

# open files.txt which contains all the file paths
with open('/home/minghui/github/NMSS/vitmae/files.txt', 'r') as f:
    files = f.readlines()
files = [x.strip() for x in files]

data_dir = '/media/minghui/Data/Datasets/NMSS Study/yuxin/agg_normalized/'
label_csv = '/home/minghui/github/NMSS/vitmae/subject_list.csv'

# custom torch transform to select num_channels random channels
class RandomChannelSelect:
    def __init__(self, num_channels=8):
        self.num_channels = num_channels

    def __call__(self, img):
        # img is a 4D tensor of shape (1, C, H, W)
        # randomly select a starting channel
        start_channel = np.random.randint(0, img.shape[0] - self.num_channels)
        img = img[start_channel:start_channel+self.num_channels, :, :]
        return img

class RandomCrop3D:
    def __init__(self, size=64):
        self.size = size
    
    def __call__(self, img):
        # img is a 3D tensor of shape (x, y, z)
        # randomly select a starting channel
        x_start = np.random.randint(0, img.shape[0] - self.size)
        y_start = np.random.randint(0, img.shape[1] - self.size)
        z_start = np.random.randint(0, img.shape[2] - self.size)
        img = img[x_start:x_start+self.size, y_start:y_start+self.size, z_start:z_start+self.size]
        return img

class RandomDimensionPermute:
    def __init__(self):
        pass

    def __call__(self, img):
        # img is a 4D tensor of shape (1, C, H, W)
        # randomly permute the dimensions
        dims = ['x', 'y', 'z']
        np.random.shuffle(dims)
        img = rearrange(img, 'x y z -> {}'.format(' '.join(dims)))
        return img

custom_transforms = T.Compose([
    RandomCrop3D(64),
    RandomDimensionPermute(),
    # RandomChannelSelect(16),
    # T.RandomHorizontalFlip(),
    # T.RandomVerticalFlip(),
    # T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10, -10, 10)),
    # T.RandomAdjustSharpness(0.5),
    # T.GaussianBlur(3, sigma=(0.1, 0.5)),
    # T.RandomRotation(90),
    # T.RandomCrop(64),
])

# train_dataset = SPRINT_T1w_flat_Dataset(data_dir, files, label_csv, mode='train', transform=custom_transforms)
# val_dataset = SPRINT_T1w_flat_Dataset(data_dir, val_set, label_csv, 'val', transform=custom_transforms)
# test_dataset = SPRINT_T1w_flat_Dataset(data_dir, test_set, label_csv, 'test', transform=custom_transforms)

train_dataset = SPRINT_T1w_flat_Dataset(data_dir, files, label_csv, mode='train')
val_dataset = SPRINT_T1w_flat_Dataset(data_dir, val_set, label_csv, 'val')
test_dataset = SPRINT_T1w_flat_Dataset(data_dir, test_set, label_csv, 'test')

Total subjects: 751, Progressing: 326, Not progressing: 425
Total subjects: 75, Progressing: 33, Not progressing: 42
Total subjects: 76, Progressing: 29, Not progressing: 47


In [17]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# count number of progressing in train_dataset
progress = 0
for i in range(len(train_dataset)):
    if train_dataset[i]['label'][1] == 1:
        progress += 1
total, nonprogress = len(train_dataset), len(train_dataset) - progress
print(f"Total subjects: {total}, Progressing: {progress}, Not progressing: {nonprogress}")

Total subjects: 751, Progressing: 326, Not progressing: 425


In [29]:
from einops import rearrange
import torch.nn.functional as F

def train(model, train_loader, val_loader, optimizer, loss_fn, scheduler, epochs):
    loss_hist = []

    for epoch in range(epochs):
        # training
        model.train()
        train_correct = train_total = 0
        for batch_idx, datum in enumerate(train_loader):
            image = datum['image'].cuda()
            label = datum['label'].cuda()
            # print('[debug] label: ', label)
            # print('[debug] label.shape: ', label.shape)
            # print('[debug] image.shape: ', image.shape)
            optimizer.zero_grad()

            output = model(image)
            # pass through softmax
            output = F.softmax(output.logits, dim=-1)
            # print('[debug] output.shape: ', output.shape)

            loss = loss_fn(output, label)
            loss_hist.append(loss.item())
            loss.backward()
            # train accuracy
            pred = output.argmax(dim=-1)
            label = label.argmax(dim=-1)
            train_correct += (pred == label).sum()
            train_total += label.shape[0]
            optimizer.step()
        scheduler.step()
        # avg loss 
        avg_loss = sum(loss_hist) / len(loss_hist)
        loss_hist.clear()
        acc = train_correct / train_total
        print(f'Epoch: {epoch}, Loss: {avg_loss}, Accuracy: {train_correct}/{train_total}, {acc*100}%')

        # validation
        model.eval()
        with torch.no_grad():
            val_correct = val_total = 0
            for batch_idx, datum in enumerate(val_loader):
                image = datum['image'].cuda()
                label = datum['label'].cuda()
                output = model(image)
                output = F.softmax(output.logits, dim=-1)

                pred = output.argmax(dim=-1)
                label = label.argmax(dim=-1)
                val_correct += (pred == label).sum().item()
                val_total += label.size(0)
            val_acc = val_correct / val_total
            print(f'Validation Accuracy: {val_correct}/{val_total}, {val_acc*100}%')

import torch
from torch.nn import CrossEntropyLoss, BCELoss

model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
normalized_class_weights = torch.tensor([progress/total, nonprogress/total]).cuda()
loss_fn = CrossEntropyLoss(weight=normalized_class_weights)
# loss_fn = BCELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

epochs = 400
train(model, train_loader, val_loader, optimizer, loss_fn, scheduler, epochs)

Epoch: 0, Loss: 0.40465514076516984, Accuracy: 406/751, 54.06125259399414%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 1, Loss: 0.3995908990185312, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 2, Loss: 0.3995370668299655, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 3, Loss: 0.3995909037742209, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 4, Loss: 0.39953707253679316, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 5, Loss: 0.3994294116471676, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 6, Loss: 0.3995370715856552, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 7, Loss: 0.3994832387629976, Accuracy: 425/751, 56.59120559692383%
Validation Accuracy: 42/75, 56.00000000000001%
Epoch: 8, Loss: 0.399698561176