In [1]:
import torch
import os
import yaml

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [3]:
from mri_dataset import ADNIDataset
from monai.transforms import *
from torch.utils.data import DataLoader, random_split


def get_data_loaders(batch_size=256):
    dataset_dir = r"E:\Data\ADNI\adni-fnirt-corrected"
    csv_path = r"E:\Data\ADNI\single_subject.csv"
    size = 100
    data_transforms = Compose([
        # RandRotate90(prob=0.5, spatial_axes=[1, 2]),
        # RandFlip(prob=0.5, spatial_axis=0),
        
        # RandAdjustContrast(prob=0.5),
        # RandGaussianNoise(prob=0.3),
        # RandAffine(prob=0.5, translate_range=10, scale_range=(0.9, 1.1), rotate_range=45),
        
        Resize(spatial_size=[size, size, size]),
        NormalizeIntensity(nonzero=True, channel_wise=True),
    ])
    dataset = ADNIDataset(data_dir=dataset_dir, csv_path=csv_path, transform=data_transforms)
    dataset_size = len(dataset)
    train_size = int(dataset_size * 0.7)
    test_size = dataset_size - train_size
    print('dataset_size:', dataset_size)
    print('train_size:', train_size)
    print('test_size:', test_size)
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [4]:
log_dir = './runs/simclr_vgg_150'

with open(os.path.join(log_dir, 'config.yml')) as file:
  config = yaml.load(file, Loader=yaml.SafeLoader)
print(config)

{'arch': 'vgg', 'batch_size': 8, 'csv_path': 'E:\\Data\\ADNI\\pheno_ADNI_longitudinal_new.csv', 'dataset_dir': 'E:\\Data\\ADNI\\adni-fnirt-corrected', 'dataset_name': 'mri', 'device': 'cuda', 'disable_cuda': False, 'epochs': 150, 'fp16_precision': False, 'learning_rate': 1e-05, 'log_every_n_steps': 100, 'n_views': 2, 'out_dim': 128, 'temperature': 0.07, 'weight_decay': 0.0001}


In [5]:
from model import Simple3DCNN, VoxVGG, VoxResNet

if config['arch'] == 'simple':
    model = Simple3DCNN(class_nums=3)
elif config['arch'] == 'vgg':
    model = VoxVGG(class_nums=3)
elif config['arch'] == 'resnet':
    model = VoxResNet(class_nums=3)

model = model.to(device)
print(model)

VoxVGG(
  (conv1): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv2): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv3): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv4): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv5): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv6): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv7): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv8): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv9): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (conv10): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (fc): Linear(in_features=64, out_features=128, bias=True)
  (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (last_fc): Linear(in_features=128, out_features=3, bias=True)
)


In [6]:
checkpoint_filename = 'checkpoint_{:04}.pth.tar'.format(config['epochs'])
checkpoint_path = os.path.join(log_dir, checkpoint_filename)
print(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['state_dict']
print('keys:', list(state_dict.keys()))

for k in list(state_dict.keys()):
  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.last_fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
    else:
      print(k)
  del state_dict[k]

./runs/simclr_vgg_150\checkpoint_0150.pth.tar
keys: ['backbone.conv1.weight', 'backbone.conv1.bias', 'backbone.conv2.weight', 'backbone.conv2.bias', 'backbone.conv3.weight', 'backbone.conv3.bias', 'backbone.conv4.weight', 'backbone.conv4.bias', 'backbone.conv5.weight', 'backbone.conv5.bias', 'backbone.conv6.weight', 'backbone.conv6.bias', 'backbone.conv7.weight', 'backbone.conv7.bias', 'backbone.conv8.weight', 'backbone.conv8.bias', 'backbone.conv9.weight', 'backbone.conv9.bias', 'backbone.conv10.weight', 'backbone.conv10.bias', 'backbone.fc.weight', 'backbone.fc.bias', 'backbone.bn.weight', 'backbone.bn.bias', 'backbone.bn.running_mean', 'backbone.bn.running_var', 'backbone.bn.num_batches_tracked', 'backbone.last_fc.0.weight', 'backbone.last_fc.0.bias', 'backbone.last_fc.2.weight', 'backbone.last_fc.2.bias']
backbone.last_fc.0.weight
backbone.last_fc.0.bias
backbone.last_fc.2.weight
backbone.last_fc.2.bias


In [7]:
log = model.load_state_dict(state_dict, strict=False)
print(log.missing_keys)
assert log.missing_keys == ['last_fc.weight', 'last_fc.bias']

['last_fc.weight', 'last_fc.bias']


In [8]:
if config['dataset_name'] == 'mri':
    train_loader, test_loader = get_data_loaders(batch_size=config['batch_size'] * 2)

dataset_size: 980
train_size: 686
test_size: 294


In [9]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
    else:
        print(name)

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

fc.weight
fc.bias


In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [11]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [12]:
epochs = 50
for epoch in range(epochs):
    top1_train_accuracy = 0
    print(f'epoch:{epoch + 1}, train')
    for counter, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= counter + 1
    top1_accuracy = 0
    top3_accuracy = 0
    print(f'epoch:{epoch + 1}, test')
    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top3 = accuracy(logits, y_batch, topk=(1, 3))
        top1_accuracy += top1[0]
        top3_accuracy += top3[0]

    top1_accuracy /= counter + 1
    top3_accuracy /= counter + 1
    print(
        f"Epoch {epoch + 1}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop3 test acc: {top3_accuracy.item()}"
    )
    state = {
        'model': model.state_dict(),
    }
    torch.save(state, os.path.join('checkpoint', 'checkpoint_eval.pth'))

epoch:1, train
epoch:1, test
Epoch 1	Top1 Train accuracy 38.82890319824219	Top1 Test accuracy: 39.47368621826172	Top3 test acc: 100.0
epoch:2, train
epoch:2, test
Epoch 2	Top1 Train accuracy 47.21760940551758	Top1 Test accuracy: 44.51754379272461	Top3 test acc: 100.0
epoch:3, train
epoch:3, test
Epoch 3	Top1 Train accuracy 53.50913619995117	Top1 Test accuracy: 48.355262756347656	Top3 test acc: 100.0
epoch:4, train
epoch:4, test
Epoch 4	Top1 Train accuracy 55.481727600097656	Top1 Test accuracy: 45.94298553466797	Top3 test acc: 100.0
epoch:5, train
epoch:5, test
Epoch 5	Top1 Train accuracy 52.86545181274414	Top1 Test accuracy: 45.28508758544922	Top3 test acc: 100.0
epoch:6, train
epoch:6, test
Epoch 6	Top1 Train accuracy 55.793190002441406	Top1 Test accuracy: 45.72368621826172	Top3 test acc: 100.0
epoch:7, train
epoch:7, test
Epoch 7	Top1 Train accuracy 52.595516204833984	Top1 Test accuracy: 48.135963439941406	Top3 test acc: 100.0
epoch:8, train
epoch:8, test
Epoch 8	Top1 Train accuracy 