In [1]:
from transformers import ViTMAE3DConfig, ViTMAE3DForPreTraining


# Encoder/decoder configuration in the style of vit-mae-base, adapted to
# single-channel 91x91x91 volumes split into 7-voxel patches (91 / 7 = 13 per
# axis, presumably cubic patches — confirm against ViTMAE3D's patch embedding).
config = ViTMAE3DConfig(
    num_channels=1,
    image_size=91,
    patch_size=7,
    # NOTE(review): the printed config also carries hidden_size=768; confirm that
    # ViTMAE3DConfig actually consumes `embed_dim` rather than storing it as an extra.
    embed_dim=768,
    mask_ratio=0.25,
    # norm_pix_loss stays at its default (printed as false below).
)

# Randomly initialised 3D ViT-MAE pretraining model built from that configuration.
vit_mae = ViTMAE3DForPreTraining(config)

# Echo the resolved configuration, including defaults filled in by the class.
_configuration = vit_mae.config
print(_configuration)


  from .autonotebook import tqdm as notebook_tqdm
2023-06-13 14:00:04.742150: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-06-13 14:00:04.742187: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


ViTMAE3DConfig {
  "attention_probs_dropout_prob": 0.0,
  "decoder_hidden_size": 384,
  "decoder_intermediate_size": 1536,
  "decoder_num_attention_heads": 16,
  "decoder_num_hidden_layers": 8,
  "embed_dim": 768,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 91,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "mask_ratio": 0.25,
  "model_type": "vit_mae_3d",
  "norm_pix_loss": false,
  "num_attention_heads": 12,
  "num_channels": 1,
  "num_hidden_layers": 12,
  "patch_size": 7,
  "qkv_bias": true,
  "transformers_version": "4.30.0.dev0"
}



In [4]:
import re
import os
import csv
import numpy as np
import torch
from einops import rearrange
from torch.utils.data import Dataset
from monai.data import NibabelReader

class SPRINT_T1w_flat_Dataset(Dataset):
    """Map-style dataset pairing NIfTI volumes with a per-subject integer label.

    Each item is ``{'image': tensor, 'label': int}``. The image is read with
    MONAI's ``NibabelReader``, converted to float32 and cropped to slices 9:100
    along dimension 1. ``transform`` (if any) is applied to that cropped volume,
    and finally a leading channel axis is added (``'d h w -> 1 d h w'``).

    The subject id used to look up the label is the 4-digit second group of the
    ``subject_XXX-YYYY`` pattern found in the file name.

    Args:
        data_dir: directory that every entry of ``filenames`` is relative to.
        filenames: file names of the volumes (joined onto ``data_dir``).
        subjects_csv: CSV whose first column is the subject id and whose fifth
            column is an integer label. Rows whose first cell is ``subject_id``
            (the header) and blank rows are skipped.
        mode: stored as ``self.mode``; not otherwise used by this class.
        transform: optional callable applied to the cropped 3D volume.

    Raises:
        ValueError: a file name does not contain ``subject_XXX-YYYY``.
        KeyError: a subject in ``filenames`` has no row in ``subjects_csv``.
    """

    # Compiled once; group(2) is the 4-digit subject id.
    _SUBJECT_RE = re.compile(r'subject_(\d{3})-(\d{4})')

    def __init__(self, data_dir, filenames, subjects_csv, mode='train', transform=None):
        self.data_dir = data_dir
        self.filenames = filenames
        self.mode = mode
        self.transform = transform

        # subject id (str) -> integer label
        self.labels = self._load_labels(subjects_csv)

        # Validates that every file maps to a labelled subject (fails fast with
        # the offending file name) and reports the class balance.
        progress = sum(1 for filename in self.filenames if self._label_for(filename) == 1)
        not_progress = len(self.filenames) - progress
        print(f"Total subjects: {len(self.filenames)}, Progressing: {progress}, Not progressing: {not_progress}")

        self.image_reader = NibabelReader()

    @staticmethod
    def _load_labels(subjects_csv):
        """Read ``{subject_id: int(label)}`` from columns 0 and 4 of ``subjects_csv``."""
        labels = {}
        # newline='' is what the csv module documents for files it reads.
        with open(subjects_csv, "r", newline="") as fp:
            for row in csv.reader(fp, delimiter=","):
                if not row or row[0] == "subject_id":
                    continue
                labels[row[0]] = int(row[4])
        return labels

    @classmethod
    def _parse_subject_id(cls, filename):
        """Return the 4-digit subject id embedded in ``filename``."""
        match = cls._SUBJECT_RE.search(filename)
        if match is None:
            raise ValueError(f"Cannot find 'subject_XXX-YYYY' in file name: {filename!r}")
        return match.group(2)

    def _label_for(self, filename):
        """Look up the label of the subject that ``filename`` belongs to."""
        subject_id = self._parse_subject_id(filename)
        try:
            return self.labels[subject_id]
        except KeyError:
            raise KeyError(
                f"No label for subject {subject_id!r} (file {filename!r}) in subjects CSV"
            ) from None

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        # Resolve the label first so a bad file name fails before any disk I/O.
        label = self._label_for(filename)

        image = self.image_reader.read(os.path.join(self.data_dir, filename))
        image = torch.from_numpy(image.get_fdata().astype(np.float32))
        # leave only middle 91 channels in dimension 1
        # NOTE(review): 9:100 is the centred 91-slice window only if dim 1 has 109 slices.
        image = image[:, 9:100, :]
        if self.transform is not None:
            image = self.transform(image)
        image = rearrange(image, 'd h w -> 1 d h w')

        return {'image': image, 'label': label}

In [5]:
from typing import Any
from torch.utils.data import DataLoader
import torchvision.transforms as T

# files.txt lists one volume file name per line; the dataset joins each name onto
# data_dir. Blank lines are dropped so an empty entry never reaches the
# subject-id regex in the dataset.
# NOTE(review): these are absolute, machine-specific paths — consider moving them
# to a single config cell (or environment variables) at the top of the notebook.
with open('/home/minghui/github/NMSS/vitmae/files.txt', 'r') as f:
    files = [line.strip() for line in f if line.strip()]

data_dir = '/media/minghui/Data/Datasets/NMSS Study/yuxin/agg_normalized/'
label_csv = '/home/minghui/github/NMSS/vitmae/subject_list.csv'

# custom torch transform to select a random contiguous run of num_channels slices
class RandomChannelSelect:
    """Crop a random contiguous window of ``num_channels`` entries along axis 0.

    ``img`` is indexed as ``img[start:start + num_channels, :, :]``, so it must
    have at least three dimensions (e.g. a (d, h, w) volume); only axis 0 is cropped.

    Raises:
        ValueError: ``num_channels`` is larger than ``img.shape[0]``.
    """

    def __init__(self, num_channels=8):
        self.num_channels = num_channels

    def __call__(self, img):
        available = img.shape[0]
        if self.num_channels > available:
            raise ValueError(
                f"RandomChannelSelect: requested {self.num_channels} channels, input has {available}"
            )
        # np.random.randint excludes `high`; +1 makes the last window reachable and
        # lets num_channels == available work instead of raising.
        start_channel = np.random.randint(0, available - self.num_channels + 1)
        return img[start_channel:start_channel + self.num_channels, :, :]

class RandomCrop3D:
    """Randomly crop a ``size``-voxel cube from the first three axes of ``img``.

    ``img`` is indexed as ``img[x, y, z]``; each of its first three dimensions
    must be at least ``size``.

    Raises:
        ValueError: ``size`` exceeds one of the first three dimensions.
    """

    def __init__(self, size=64):
        self.size = size

    def __call__(self, img):
        size = self.size
        dims = (img.shape[0], img.shape[1], img.shape[2])
        if min(dims) < size:
            raise ValueError(f"RandomCrop3D: crop size {size} exceeds input dims {tuple(dims)}")
        # np.random.randint excludes `high`; +1 lets the window end on the last
        # voxel and makes size == dim valid instead of raising.
        x_start = np.random.randint(0, dims[0] - size + 1)
        y_start = np.random.randint(0, dims[1] - size + 1)
        z_start = np.random.randint(0, dims[2] - size + 1)
        return img[x_start:x_start + size, y_start:y_start + size, z_start:z_start + size]

class RandomDimensionPermute:
    """Return ``img`` with its three axes rearranged into a random order.

    The einops pattern names exactly three axes, so ``img`` must be 3D
    (e.g. a (d, h, w) volume), not the 4D layout an earlier comment suggested.
    The order is drawn with ``np.random.shuffle``, i.e. from NumPy's global RNG.
    """

    def __call__(self, img):
        axes = ['x', 'y', 'z']
        np.random.shuffle(axes)
        pattern = 'x y z -> ' + ' '.join(axes)
        return rearrange(img, pattern)

# Augmentations are applied by the dataset to each cropped (d, h, w) volume,
# before the channel axis is added. Only a light Gaussian blur is active;
# torchvision treats the leading axis as channels, so each slice is blurred in-plane.
# Previously tried and disabled: RandomCrop3D(80), RandomDimensionPermute(),
# RandomChannelSelect(16), T.RandomHorizontalFlip(), T.RandomVerticalFlip(),
# T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10, -10, 10)),
# T.RandomAdjustSharpness(0.5), T.RandomRotation(90), T.RandomCrop(64), random noise.
train_transforms = T.Compose([
    T.GaussianBlur(3, sigma=(0.05, 0.2)),
])

train_dataset = SPRINT_T1w_flat_Dataset(
    data_dir=data_dir,
    filenames=files,
    subjects_csv=label_csv,
    mode='train',
    transform=train_transforms,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True, num_workers=4)

# Pull a single batch to sanity-check tensor shapes.
for batch in train_loader:
    image = batch['image']
    print(image.shape)
    label = batch['label']
    break

Total subjects: 751, Progressing: 326, Not progressing: 425
torch.Size([2, 1, 91, 91, 91])


In [6]:
def train(model, train_loader, optimizer, scheduler, epochs, device=None):
    """Run the pretraining loop and report the mean loss of every epoch.

    The model is called on each batch's ``'image'`` tensor and must return an
    object with a scalar ``.loss`` attribute; batch labels are not used.

    Args:
        model: module to train; switched to train mode at the start of each epoch.
        train_loader: iterable of dicts containing at least an ``'image'`` tensor.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        scheduler: learning-rate scheduler; stepped once per epoch.
        epochs: number of passes over ``train_loader``.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so a model already moved with
            ``.cuda()`` trains on the GPU as before.

    Returns:
        List with the mean batch loss of each epoch (``nan`` for an epoch with
        no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        batch_losses = []
        for datum in train_loader:
            image = datum['image'].to(device)
            optimizer.zero_grad()
            loss = model(image).loss
            batch_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        scheduler.step()

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        epoch_losses.append(avg_loss)
        print(f'Epoch: {epoch}, Loss: {avg_loss}')
    return epoch_losses

import torch

# Stage 1: 50 epochs of Adam at lr 1e-3, halving the learning rate every 25 epochs.
epochs = 50

# Module.cuda() moves the parameters in place and returns the same module.
model = vit_mae.cuda()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=25, gamma=0.5)

train(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
)

Epoch: 0, Loss: 2441.332264677007
Epoch: 1, Loss: 2385.692653899497
Epoch: 2, Loss: 2351.358447785073
Epoch: 3, Loss: 2432.1738834786925
Epoch: 4, Loss: 2126.4023198878513
Epoch: 5, Loss: 1551.0795330291098
Epoch: 6, Loss: 2320.979242852394
Epoch: 7, Loss: 2454.360929773209
Epoch: 8, Loss: 2358.950841700777
Epoch: 9, Loss: 1059.1752731647898
Epoch: 10, Loss: 506.33268802723984
Epoch: 11, Loss: 480.0423789328717
Epoch: 12, Loss: 462.70029749768844
Epoch: 13, Loss: 453.2670467457873
Epoch: 14, Loss: 448.8441321190367
Epoch: 15, Loss: 444.7785218421449
Epoch: 16, Loss: 440.58440220609623
Epoch: 17, Loss: 437.936207872756
Epoch: 18, Loss: 437.64095898892015
Epoch: 19, Loss: 434.38005609715236
Epoch: 20, Loss: 513.0692846419963
Epoch: 21, Loss: 592.7511921334774
Epoch: 22, Loss: 469.1528213176322
Epoch: 23, Loss: 452.560278222916
Epoch: 24, Loss: 447.7554191427028
Epoch: 25, Loss: 438.38647882989113
Epoch: 26, Loss: 433.5987190895892
Epoch: 27, Loss: 430.7721763773167
Epoch: 28, Loss: 418.9

In [None]:
# Stage 2: 50 further epochs at a lower learning rate (1e-4, halved every 10 epochs).
# A fresh Adam instance is created, so moment estimates from stage 1 are not carried over.
epochs = 50

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)

train(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    epochs=epochs,
)

In [7]:
# Save the trained model twice under one shared, timestamped run name:
# a save_pretrained() directory and a raw state_dict (.pth) next to it.
from datetime import datetime

# NOTE(review): `epochs` is only a label in the file names and is hard-coded here.
# It does not track how many epochs were actually trained (the first training
# log shown ends at epoch 28 and the follow-up cell shows In [None]); set it to
# the real count before saving.
epochs = 50
mask_pct = round(model.config.mask_ratio * 100)  # 0.25 -> 25
dt = datetime.now().strftime("%Y%m%d-%H%M%S")
# One stem for both artifacts so they can be paired; previously the .pth name
# omitted the "_3d" that the directory name had.
run_name = f'pretrained_vit_mae_3d_{mask_pct}pct_{epochs}_epochs_{dt}'
model.save_pretrained(f'./{run_name}')
torch.save(model.state_dict(), f'./{run_name}.pth')