In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device:", device)

Using device: cpu


In [None]:
# data_dir_stl10 = r'C:\Custom\DataSet\STL10'
# data_dir_cifar10 = r'C:\Custom\DataSet\CIFAR10'

# def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
#   train_dataset = datasets.STL10(data_dir_stl10, split='train', download=download,
#                                   transform=transforms.ToTensor())

#   train_loader = DataLoader(train_dataset, batch_size=batch_size,
#                             num_workers=0, drop_last=False, shuffle=shuffle)
  
#   test_dataset = datasets.STL10(data_dir_stl10, split='test', download=download,
#                                   transform=transforms.ToTensor())

#   test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
#                             num_workers=10, drop_last=False, shuffle=shuffle)
#   return train_loader, test_loader

# def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
#   train_dataset = datasets.CIFAR10(data_dir_cifar10, train=True, download=download,
#                                   transform=transforms.ToTensor())

#   train_loader = DataLoader(train_dataset, batch_size=batch_size,
#                             num_workers=0, drop_last=False, shuffle=shuffle)
  
#   test_dataset = datasets.CIFAR10(data_dir_cifar10, train=False, download=download,
#                                   transform=transforms.ToTensor())

#   test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
#                             num_workers=10, drop_last=False, shuffle=shuffle)
#   return train_loader, test_loader

In [3]:
from mri_dataset import ADNIDataset
from monai.transforms import *
from torch.utils.data import DataLoader, random_split


def get_cifar10_data_loaders(batch_size=256, download=None,
                             dataset_dir=r"E:\Data\ADNI\adni-fnirt-corrected",
                             csv_path=r"E:\Data\ADNI\pheno_ADNI_longitudinal_new.csv",
                             size=100, train_fraction=0.7, seed=0):
    """Build train/test DataLoaders for the ADNI MRI dataset.

    Despite its name (kept so existing call sites keep working), this loads
    volumes through ``ADNIDataset``, not CIFAR-10.

    Args:
        batch_size: Batch size used by both loaders.
        download: Ignored. Accepted only so calls written for the torchvision
            loaders (``get_cifar10_data_loaders(download=True)``) do not fail.
        dataset_dir: Directory passed to ``ADNIDataset`` as ``data_dir``.
        csv_path: CSV file passed to ``ADNIDataset`` as ``csv_path``.
        size: Every volume is resized to ``size`` x ``size`` x ``size``.
        train_fraction: Fraction of samples assigned to the training split.
        seed: Seed for the train/test permutation, so the split is reproducible.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and the test
        loader does not.
    """
    from torch.utils.data import Subset

    del download  # kept only for call-site compatibility

    # Random augmentations belong to training only; the test split must see a
    # deterministic pipeline or reported accuracies change from run to run.
    train_transforms = Compose([
        RandRotate90(prob=0.5, spatial_axes=[1, 2]),
        RandFlip(prob=0.5, spatial_axis=0),

        RandAdjustContrast(prob=0.5),
        RandGaussianNoise(prob=0.3),
        RandAffine(prob=0.5, translate_range=10, scale_range=(0.9, 1.1), rotate_range=45),

        Resize(spatial_size=[size, size, size]),
        NormalizeIntensity(nonzero=True, channel_wise=True),
    ])
    eval_transforms = Compose([
        Resize(spatial_size=[size, size, size]),
        NormalizeIntensity(nonzero=True, channel_wise=True),
    ])

    # Two views over the same files that differ only in their transform. Both are
    # indexed with the same permutation below. This assumes ADNIDataset enumerates
    # samples in the same order on every construction -- TODO confirm.
    train_view = ADNIDataset(data_dir=dataset_dir, csv_path=csv_path, transform=train_transforms)
    eval_view = ADNIDataset(data_dir=dataset_dir, csv_path=csv_path, transform=eval_transforms)

    dataset_size = len(train_view)
    train_size = int(dataset_size * train_fraction)
    test_size = dataset_size - train_size
    print('dataset_size:', dataset_size)
    print('train_size:', train_size)
    print('test_size:', test_size)

    # Seeded permutation (the same scheme random_split uses) so train and test
    # never overlap and the split is identical across kernel restarts.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(dataset_size, generator=generator).tolist()
    train_dataset = Subset(train_view, indices[:train_size])
    test_dataset = Subset(eval_view, indices[train_size:])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [4]:
log_dir = './runs/Mar19_12-04-45_DESKTOP-ZERO'

# Hyper-parameters recorded next to the checkpoints of this pre-training run.
config_path = os.path.join(log_dir, 'config.yml')
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)
print(config)

{'arch': 'resnet18', 'batch_size': 8, 'csv_path': 'E:\\Data\\ADNI\\pheno_ADNI_longitudinal_new.csv', 'dataset_dir': 'E:\\Data\\ADNI\\adni-fnirt-corrected', 'dataset_name': 'mri', 'device': 'cuda', 'disable_cuda': False, 'epochs': 20, 'fp16_precision': False, 'learning_rate': 0.0003, 'log_every_n_steps': 100, 'n_views': 2, 'out_dim': 128, 'temperature': 0.07, 'weight_decay': 0.0001}


In [None]:
# if config['arch'] == 'resnet18':
#   model = torchvision.models.resnet18(num_classes=10).to(device)
# elif config['arch'] == 'resnet50':
#   model = torchvision.models.resnet50(num_classes=10).to(device)

In [6]:
from model import Simple3DCNN_SIMCLR

# Only the literal arch value 'resnet' requests a ResNet, and that path is not
# implemented here. Fail loudly instead of leaving `model` undefined (or silently
# reusing a stale `model` from an earlier kernel state). Any other value,
# including the 'resnet18' stored in config.yml, builds Simple3DCNN_SIMCLR. That
# is the architecture whose keys (conv1/conv2/fc1/fc2/fc3) the checkpoint holds.
if config['arch'] == 'resnet':
    raise NotImplementedError("arch 'resnet' is not supported by this notebook yet")

# Move the model to the same device the batches are sent to in the training loop.
model = Simple3DCNN_SIMCLR(out_dim=config['out_dim']).to(device)

print(model)

Simple3DCNN_SIMCLR(
  (conv1): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (pool): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (fc1): Linear(in_features=27648, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=128, bias=True)
)


In [22]:
def extract_backbone_state_dict(raw_state_dict, prefix='backbone.', head_prefix='backbone.fc'):
    """Return the encoder weights of a SimCLR checkpoint, ready for ``load_state_dict``.

    - Keys starting with ``head_prefix`` (the projection head of a wrapped
      backbone) are printed and dropped.
    - Keys starting with ``prefix`` have that prefix stripped.
    - All other keys are kept unchanged. This checkpoint's keys (conv1.*, fc1.*,
      ...) carry no ``backbone.`` prefix, and discarding them would leave the
      model randomly initialised.

    The input mapping is not modified; a new dict is returned.
    """
    backbone_state = {}
    for key, value in raw_state_dict.items():
        if key.startswith(head_prefix):
            print(key)
            continue
        if key.startswith(prefix):
            key = key[len(prefix):]
        backbone_state[key] = value
    return backbone_state


checkpoint_filename = 'checkpoint_{:04}.pth.tar'.format(config['epochs'])
checkpoint_path = os.path.join(log_dir, checkpoint_filename)
print(checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location=device)
print('keys:', list(checkpoint['state_dict'].keys()))

state_dict = extract_backbone_state_dict(checkpoint['state_dict'])

./runs/Mar19_12-04-45_DESKTOP-ZERO\checkpoint_0020.pth.tar
keys: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']


In [None]:
log = model.load_state_dict(state_dict, strict=False)
# Only a freshly initialised classification head may be missing. Any other
# missing key means the checkpoint did not actually initialise the encoder.
# For example, every key is missing when the stripped state dict came out empty.
allowed_missing = {'fc.weight', 'fc.bias'}
assert set(log.missing_keys) <= allowed_missing, f"weights missing from checkpoint: {log.missing_keys}"
assert not log.unexpected_keys, f"checkpoint keys not used by the model: {log.unexpected_keys}"

In [None]:
dataset_name = config['dataset_name']
if dataset_name == 'mri':
    # get_cifar10_data_loaders (defined above) builds the ADNI MRI loaders despite its name.
    train_loader, test_loader = get_cifar10_data_loaders()
else:
    # The torchvision CIFAR-10/STL-10 loaders are commented out above. Raise here
    # instead of leaving train_loader/test_loader undefined for the cells below.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}: only 'mri' has a loader in this notebook"
    )
print("Dataset:", dataset_name)

In [None]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False
    else:
        print(name)

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [None]:
# Hand Adam only the trainable (head) parameters. Frozen tensors never receive
# gradients, so including them only adds optimizer bookkeeping. An empty list
# here makes Adam raise immediately, which exposes a freezing mistake early.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [None]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores of shape (batch, num_classes); classes are ranked along dim 1.
        target: Class indices, one per sample in the batch.
        topk: The values of k to evaluate.

    Returns:
        A list with one 1-element tensor per k. Each holds the percentage of
        samples whose target is among the k highest-scoring classes. A k larger
        than the number of classes is clamped to it, so every target counts as
        a hit; without the clamp, ``Tensor.topk`` would raise.
    """
    with torch.no_grad():
        # topk() raises when k exceeds the class dimension (e.g. top-5 with 3 classes).
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch): row i holds each sample's i-th best class
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [None]:
import tqdm

epochs = 1
for epoch in range(epochs):
    # Train the unfrozen head. Accuracies are accumulated weighted by batch size,
    # so a short final batch does not skew the epoch mean.
    model.train()
    train_top1_sum, n_train = 0.0, 0
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        top1, = accuracy(logits, y_batch, topk=(1,))
        train_top1_sum += top1.item() * y_batch.size(0)
        n_train += y_batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate in eval mode (inference behaviour for any dropout/batch-norm layers)
    # and without building autograd graphs, which only waste memory here.
    model.eval()
    test_top1_sum, test_top5_sum, n_test = 0.0, 0.0, 0
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(x_batch)

            top1, top5 = accuracy(logits, y_batch, topk=(1, 5))
            test_top1_sum += top1.item() * y_batch.size(0)
            test_top5_sum += top5.item() * y_batch.size(0)
            n_test += y_batch.size(0)

    # Report NaN rather than raising NameError/ZeroDivisionError when a loader yields no batches.
    top1_train_accuracy = train_top1_sum / n_train if n_train else float('nan')
    top1_accuracy = test_top1_sum / n_test if n_test else float('nan')
    top5_accuracy = test_top5_sum / n_test if n_test else float('nan')
    print(
        f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy}\tTop1 Test accuracy: {top1_accuracy}\tTop5 test acc: {top5_accuracy}"
    )