In [2]:
import torch
import torchvision

from torchmil.datasets import ToyDataset

# Load the MNIST training split from the local cache. With download=False
# torchvision does not fetch anything, so the files must already be present
# under ~/.cache/.
mnist_dataset = torchvision.datasets.MNIST(root='~/.cache/', train=True, download=False, transform=None)
targets = mnist_dataset.targets.numpy()

# Divide the raw pixel values by 255 and flatten every image into one row, so
# the array handed to PCA is 2-D: (number of images, pixels per image).
data = mnist_dataset.data.numpy().astype(float) / 255.0
data = data.reshape(data.shape[0], -1)

# perform PCA

from sklearn.decomposition import PCA

# random_state pins the randomized SVD solver (when scikit-learn selects it),
# so the 100 projected features are identical on every Restart & Run All.
pca = PCA(n_components=100, random_state=0)
data = pca.fit_transform(data)

In [None]:
from torchmil.data import collate_fn

# Build the toy bag dataset from the PCA features and the digit labels.
# NOTE(review): 2000, [0, 1] and 10 are passed positionally; presumably they
# are the number of bags, the digit labels treated as positive, and the bag
# size — confirm against ToyDataset's signature.
dataset = ToyDataset(data, targets, 2000, [0, 1], 10)

# Shuffled mini-batches of 10 bags, loaded in the main process and assembled
# with torchmil's collate_fn.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

In [4]:
from torchmil.models import ABMIL

# NOTE(review): (100,) is presumably the per-instance feature shape and has to
# match the number of PCA components used to build `data` — confirm against
# ABMIL's signature.
model = ABMIL((100,))

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [5]:
def _get_args_names(fn):
    args_names = fn.__code__.co_varnames[:fn.__code__.co_argcount]
    args_names = args_names[1:] # remove self
    return args_names


class ModelWrapper(torch.nn.Module):
    """Adapter that lets a model be called with a whole batch as keyword arguments.

    Both ``forward`` and ``compute_loss`` pick out of ``kwargs`` only the
    arguments that the wrapped model's ``compute_loss`` declares, call that
    method with them, and return whatever it returns.
    """

    def __init__(self, model):
        """Store ``model``; it must expose a ``compute_loss`` method."""
        super().__init__()
        self.model = model

    def _select_loss_kwargs(self, kwargs):
        """Return the entries of ``kwargs`` named by the wrapped ``compute_loss``.

        Raises:
            KeyError: if ``kwargs`` lacks any name that ``compute_loss`` declares.
        """
        args_names = _get_args_names(self.model.compute_loss)
        missing = [name for name in args_names if name not in kwargs]
        if missing:
            # Same exception type as a plain dict lookup, but it says what is wrong.
            raise KeyError(
                f"compute_loss of {type(self.model).__name__} needs {missing}, "
                f"but only {sorted(kwargs)} were provided"
            )
        return {name: kwargs[name] for name in args_names}

    def forward(self, **kwargs):
        """Return the wrapped model's loss for the given batch.

        NOTE(review): calling the wrapper yields ``compute_loss``'s result, not
        the wrapped model's own forward output — confirm this is what the
        training code expects.
        """
        return self.compute_loss(**kwargs)

    def compute_loss(self, **kwargs):
        """Call the wrapped ``compute_loss`` with the matching subset of ``kwargs``."""
        return self.model.compute_loss(**self._select_loss_kwargs(kwargs))

# Wrap the model exactly once, even if this cell is executed again. A nested
# wrapper would inspect the inner wrapper's `compute_loss(self, **kwargs)`,
# find no named arguments and forward nothing to the real model.
# The class *name* is compared instead of using isinstance because re-running
# this cell rebinds ModelWrapper to a new class object, so an instance created
# by the previous run would no longer pass an isinstance check.
if type(model).__name__ == "ModelWrapper":
    model = model.model
model = ModelWrapper(model)

In [6]:
from Trainer import Trainer

# Build the trainer from the wrapped model and its Adam optimizer.
# NOTE(review): Trainer comes from a local `Trainer` module that is not shown
# here — confirm it is importable from the notebook's working directory.
trainer = Trainer(model, optimizer)

In [None]:
# Train on the toy-bag dataloader.
# NOTE(review): 50 is presumably the number of epochs — confirm against
# Trainer.train's signature.
trainer.train(50, dataloader)

In [1]:
from argparse import Namespace

from utils.datasets import load_dataset

# Minimal configuration object consumed by load_dataset.
config = Namespace(
    dataset_name='rsna-features_resnet50',
    val_prop=0.1,
    seed=1,
)

# Load the test split of the configured dataset.
test_dataset = load_dataset(config, mode='test')

In [None]:
# Tally the test bags by the sum of their instance labels.
# NOTE(review): a negative sum presumably marks bags whose instance labels are
# unavailable (e.g. stored as -1) — confirm against the dataset loader.
count_missing = 0  # bags whose instance-label sum is negative
count_neg = 0      # bags whose instance-label sum is exactly zero
for i in range(len(test_dataset)):
    y_inst = test_dataset[i]['y_inst']
    # Sum once and reuse it; the two conditions below are mutually exclusive.
    label_sum = sum(y_inst)
    if label_sum < 0:
        count_missing += 1
    elif label_sum == 0:
        count_neg += 1

print(count_missing)
print(count_neg)
print(len(test_dataset))

In [1]:
from argparse import Namespace

from utils.datasets import load_dataset

# Minimal configuration object consumed by load_dataset.
config = Namespace(
    dataset_name='camelyon16-patches_512_preset-features_UNI',
    val_prop=0.1,
    seed=1,
)

# Load the training and validation splits of the configured dataset.
train_dataset, val_dataset = load_dataset(config, mode='train_val')