In [2]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import ml_collections 
import deepchest
import os

In [3]:
import copy
import time

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

In [4]:
class Net(nn.Module):
    """Multi-site image classifier: frozen pretrained ResNet-50 features + small MLP head.

    Each sample is a stack of ``num_sites`` images. Every image is encoded by the
    frozen backbone, concatenated with a learned embedding of its site id, and the
    per-site vectors are averaged over the *valid* sites only (``masks == 1``)
    before three fully connected layers produce one logit per sample.

    NOTE(review): checkpoints loaded elsewhere in this notebook are named
    ``best_model_resnet18_*`` while this class builds a ResNet-50 backbone
    (2048 features rather than 512); confirm which architecture the saved
    weights belong to before calling ``load_state_dict``.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Pretrained backbone used purely as a feature extractor.
        self.feature_extractor = torchvision.models.resnet50(pretrained=True)
        # Freeze the backbone so only the site embedding and the new head train.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Replace the final classification layer with an identity so the backbone
        # returns its pooled feature vector (fc.in_features = 2048 for ResNet-50).
        num_ftrs = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

        # Learned embedding of the site index of each image (12 possible site ids).
        embedding_dim = 8
        self.pos_embedding = nn.Embedding(num_embeddings=12, embedding_dim=embedding_dim)

        # Classification head on top of [image features, site embedding].
        self.fc1 = nn.Linear(num_ftrs + embedding_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, images, sites, masks):
        """Compute one logit per sample.

        Args:
            images: batch_size x num_sites x 3 x H x W tensor (224 x 224 in this project).
            sites: batch_size x num_sites integer site ids, each in [0, 12).
            masks: batch_size x num_sites; entries equal to 1 mark real sites,
                anything else marks padding that must not affect the result.

        Returns:
            batch_size x 1 tensor of raw logits.
        """
        # Encode each site separately; images[:, i] is already batch x 3 x H x W.
        feature_vectors = [
            self.feature_extractor(images[:, i]) for i in range(images.shape[1])
        ]
        # batch_size x num_sites x num_features
        x = torch.stack(feature_vectors, dim=1)

        # batch_size x num_sites x embedding_dim
        embedded_sites = self.pos_embedding(sites)

        # batch_size x num_sites x (num_features + embedding_dim)
        x = torch.cat([x, embedded_sites], dim=2)

        # batch_size x num_sites x 1, broadcast over the feature dimension.
        valid = (masks == 1).unsqueeze(-1).to(x.dtype)
        x = x * valid

        # Masked mean: divide by the number of valid sites rather than num_sites,
        # so padded sites do not dilute the average. Clamp avoids 0/0 for samples
        # with no valid site (their features stay all-zero).
        num_valid = valid.sum(dim=1).clamp(min=1)  # batch_size x 1
        x = x.sum(dim=1) / num_valid  # batch_size x (num_features + embedding_dim)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:

# All experiment settings, declared in one place and wrapped in a ConfigDict.
config = ml_collections.ConfigDict({
    # Training length / batching.
    "batch_size": 32,
    "num_steps": 300,
    # Preprocessing pipeline spec (see preprocessing.py); the string ";" disables it.
    "preprocessing_train_eval": "independent_dropout(.2);",
    # Fold layout with num_folds = 5:
    #   use_validation_split False -> train 4/5, test 1/5
    #   use_validation_split True  -> train 3/5, test 1/5, validation 1/5
    "use_validation_split": False,
    "num_folds": 5,
    # Worker processes for data loading (presumably DataLoader num_workers; 0 = main process).
    "num_workers": 0,
    # Dataset locations.
    "images_directory": "dataset/images.dataset2/",
    "labels_file": "dataset/labels.dataset2/diagnostic.csv",
    # Seed used when splitting into folds.
    "random_state": 0,
    # Output directory and the file the fold indices are exported to.
    "save_dir": "model_saved/",
    "export_folds_indices_file": "indices.csv",
    # Internal fold bookkeeping; leave at these values.
    "test_fold_index": 0,
    "delta_from_test_index_to_validation_index": 1,
})

In [7]:
# Build the model and read one saved checkpoint (the `.ds1` file for the current test fold).
model = Net()
checkpoint_epoch = 0
checkpoint = torch.load(
    os.path.join(
        config.save_dir,
        f"best_model_resnet18_sigmoid_epoch{checkpoint_epoch}_test_fold_index{config.test_fold_index}.ds1",
    ),
    # Map tensors to CPU so checkpoints saved on a GPU can be read on any machine.
    map_location="cpu",
)
# NOTE(review): the checkpoint is read but never applied to `model`, and its file
# name says resnet18 while Net builds a ResNet-50 backbone. Confirm the intended
# architecture and the state-dict key before calling model.load_state_dict(...).

In [13]:
import glob


def summarize_accuracies(accs):
    """Summarise a non-empty list of accuracies.

    Returns ``(mean, (min_acc, min_idx), (max_acc, max_idx))`` where the indices
    are positions in ``accs``; on ties the first occurrence is reported.
    """
    min_idx = min(range(len(accs)), key=accs.__getitem__)
    max_idx = max(range(len(accs)), key=accs.__getitem__)
    return sum(accs) / len(accs), (accs[min_idx], min_idx), (accs[max_idx], max_idx)


# Report the stored accuracy of every saved checkpoint, per dataset and epoch.
for ds in [1, 2]:
    for epoch in range(4):
        checkpoint_pattern = os.path.join(
            config.save_dir,
            f"best_model_resnet18_sigmoid_epoch{epoch}_test_fold_index*.ds{ds}-1",
        )
        # glob returns files in arbitrary filesystem order; sort so the reported
        # indices are reproducible across runs and machines.
        checkpoint_files = sorted(glob.glob(checkpoint_pattern))
        if not checkpoint_files:
            print(f"dataset {ds} epoch {epoch}: no checkpoints match {checkpoint_pattern}")
            continue
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines;
        # float() accepts a 0-dim tensor or a plain number stored under 'acc'.
        accs = [float(torch.load(path, map_location="cpu")["acc"]) for path in checkpoint_files]
        mean_acc, (min_acc, min_idx), (max_acc, max_idx) = summarize_accuracies(accs)
        print(f'dataset {ds} epoch {epoch}: {mean_acc}\t(min {min_acc} index {min_idx}, max {max_acc} index {max_idx})')
    print('+++++++++++++')

dataset 1 epoch 0: 0.6875	(min 0.53125 index 0, max 0.84375 index 2)
dataset 1 epoch 1: 0.706250011920929	(min 0.625 index 1, max 0.875 index 0)
dataset 1 epoch 2: 0.7749999761581421	(min 0.65625 index 1, max 0.90625 index 4)
dataset 1 epoch 3: 0.7250000238418579	(min 0.59375 index 4, max 0.875 index 0)
+++++++++++++
dataset 2 epoch 0: 0.515625	(min 0.46875 index 3, max 0.59375 index 1)
dataset 2 epoch 1: 0.53125	(min 0.5 index 0, max 0.59375 index 1)
dataset 2 epoch 2: 0.5	(min 0.46875 index 3, max 0.53125 index 0)
dataset 2 epoch 3: 0.5859375	(min 0.5 index 1, max 0.71875 index 2)
+++++++++++++
