In [2]:
import medmnist
from medmnist import INFO

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np

# chestmnist, retinamnist
def get_image_mean_std(dataname):
    info = INFO[dataname]
    DataClass = getattr(medmnist, info['python_class'])

    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224))
            ])

    train_dataset = DataClass(split='train', transform=transform, download=True)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=8192)
    total = info['n_samples']['train']
    mean = torch.zeros(info['n_channels'])
    std = torch.zeros(info['n_channels'])
    for images, _ in train_loader:
        num_img = len(images)
        m, s = images.mean([0,2,3]), images.std([0,2,3])
        mean += num_img * m / total
        std += np.sqrt(num_img/total) * s
    return mean, std

In [19]:
mean_stds = {}

for k in ['pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist', 'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']:
    mean, std = get_image_mean_std(k)
    print(k, mean, std)
    mean_stds[k] = {
        'mean': mean,
        'std': std
    }
print(mean_stds)

Using downloaded and verified file: /Users/naomileow/.medmnist/pathmnist.npz
pathmnist tensor([0.7405, 0.5330, 0.7058]) tensor([0.3920, 0.5636, 0.3959])
Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /Users/naomileow/.medmnist/chestmnist.npz


  0%|          | 0/82802576 [00:00<?, ?it/s]

chestmnist tensor([0.4936]) tensor([0.7392])
Downloading https://zenodo.org/record/6496656/files/dermamnist.npz?download=1 to /Users/naomileow/.medmnist/dermamnist.npz


  0%|          | 0/19725078 [00:00<?, ?it/s]

dermamnist tensor([0.7631, 0.5381, 0.5614]) tensor([0.1354, 0.1530, 0.1679])
Downloading https://zenodo.org/record/6496656/files/octmnist.npz?download=1 to /Users/naomileow/.medmnist/octmnist.npz


  0%|          | 0/54938180 [00:00<?, ?it/s]

octmnist tensor([0.1889]) tensor([0.6606])
Downloading https://zenodo.org/record/6496656/files/pneumoniamnist.npz?download=1 to /Users/naomileow/.medmnist/pneumoniamnist.npz


  0%|          | 0/4170669 [00:00<?, ?it/s]

pneumoniamnist tensor([0.5719]) tensor([0.1651])
Downloading https://zenodo.org/record/6496656/files/retinamnist.npz?download=1 to /Users/naomileow/.medmnist/retinamnist.npz


  0%|          | 0/3291041 [00:00<?, ?it/s]

retinamnist tensor([0.3984, 0.2447, 0.1558]) tensor([0.2952, 0.1970, 0.1470])
Downloading https://zenodo.org/record/6496656/files/breastmnist.npz?download=1 to /Users/naomileow/.medmnist/breastmnist.npz


  0%|          | 0/559580 [00:00<?, ?it/s]

breastmnist tensor([0.3276]) tensor([0.2027])
Downloading https://zenodo.org/record/6496656/files/bloodmnist.npz?download=1 to /Users/naomileow/.medmnist/bloodmnist.npz


  0%|          | 0/35461855 [00:00<?, ?it/s]

bloodmnist tensor([0.7943, 0.6597, 0.6962]) tensor([0.2930, 0.3292, 0.1541])
Downloading https://zenodo.org/record/6496656/files/tissuemnist.npz?download=1 to /Users/naomileow/.medmnist/tissuemnist.npz


  0%|          | 0/124962739 [00:00<?, ?it/s]

tissuemnist tensor([0.1020]) tensor([0.4443])
Downloading https://zenodo.org/record/6496656/files/organamnist.npz?download=1 to /Users/naomileow/.medmnist/organamnist.npz


  0%|          | 0/38247903 [00:00<?, ?it/s]

organamnist tensor([0.4678]) tensor([0.6105])
Downloading https://zenodo.org/record/6496656/files/organcmnist.npz?download=1 to /Users/naomileow/.medmnist/organcmnist.npz


  0%|          | 0/15527535 [00:00<?, ?it/s]

organcmnist tensor([0.4932]) tensor([0.3762])
Downloading https://zenodo.org/record/6496656/files/organsmnist.npz?download=1 to /Users/naomileow/.medmnist/organsmnist.npz


  0%|          | 0/16528536 [00:00<?, ?it/s]

organsmnist tensor([0.4950]) tensor([0.3779])


  0%|          | 0/32657407 [00:00<?, ?it/s]

In [1]:
from utils.device import get_device
from models.backbone.datasets import MEAN_STDS, DataSets
from models.backbone.trainer import Trainer

device = get_device()
MODEL_SAVE_PATH = 'models/backbone/pretrained'

In [3]:
import torchvision
import torch
from torchvision.models import ResNet50_Weights
from torch import nn

retina_ds = DataSets('retinamnist', MEAN_STDS) # 5 classes
backbone = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
backbone.fc = nn.Linear(2048, len(retina_ds.info['label']))

batch_size = 256
rtrainer = Trainer(backbone, retina_ds, batch_size, device, balance=False)

Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz


In [None]:
rtrainer.run_train(5)

In [20]:
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

(0.52, 0.7557038049792416, 1.3845270824432374)


In [None]:
import os

# (0.5125, 0.7247683647010932, 1.464751205444336) for non pretrained, non balanced
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

torch.save(rtrainer.best_model.state_dict(), os.path.join(MODEL_SAVE_PATH, 'retina_backbone_pretrained_bal.pkl'))

In [10]:
import torchvision

# 'chestmnist', 'pneumoniamnist', 'octmnist',  'retinamnist'
chest_ds = DataSets('chestmnist', MEAN_STDS)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), pretrained=False)
# patch for single channel
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

batch_size = 128
ctrainer = Trainer(backbone, chest_ds, batch_size, device)

ctrainer.run_train(30)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [10]:
import torchvision
import os
import torch

# [7996,  1950,  9261, 13914,  3988,  4375,   978,  3705,  3263, 1690,  1799,  1158,  2279,   144]
chest_ds = DataSets('chestmnist', mean_stds)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, 'cxr_backbone.pkl')))

batch_size = 256
ctrainer = Trainer(backbone, chest_ds, batch_size, device, balance=True)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [None]:
ctrainer.run_train(5)

In [14]:
ctrainer.run_eval(ctrainer.model, ctrainer.test_loader)

(0.9239799784755875, 0.6329840270919405, 0.7261211995067128)

In [None]:
# Load and save pretrained resnet from medclip
import torch
from torch import nn

from medclip import MedCLIPModel, MedCLIPVisionModel

# load MedCLIP-ResNet50
model = MedCLIPModel(vision_cls=MedCLIPVisionModel)
model.from_pretrained()

bconv_weight = model.vision_model.model.conv1.weight.mean(dim=1).unsqueeze(1)

# The resnet model was trained on CheXpert and MIMIC-CXR
backbone = model.vision_model.model
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.conv1.weight = nn.Parameter(bconv_weight)

torch.save(backbone.state_dict(), 'medclip_resnet50.pkl')

In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler

configs = {
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*10
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

PROJ_SIZE = 512
device = get_device()

In [2]:
import torchvision
import os
import torch
from models.backbone.trainer import DSTrainer
from utils.f1_loss import BalAccuracyLoss, MCCLoss, F1Loss

backbone = torchvision.models.resnet50(num_classes=len(config.classes_split_map['train']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

mtrainer = DSTrainer(backbone, query_dataset.class_labels(), criterion=F1Loss(), device=device)

In [None]:
mtrainer.run_train(5, support_loader, query_loader, lr=8e-5, weight_decay=1e-2)

In [16]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

Loss 0.8846505880355835 | F1 0.31754305958747864 | AUC 0.6390231661700312 | Specificity 0.8324683904647827 | Recall 0.3061603009700775 | Bal Acc 0.5693143606185913
Loss 0.8571950793266296 | F1 0.30419912934303284 | AUC 0.6471379088770807 | Specificity 0.8628817200660706 | Recall 0.29924583435058594 | Bal Acc 0.5810637474060059
Loss 0.8955227136611938 | F1 0.2928975224494934 | AUC 0.6467687791698298 | Specificity 0.8378623723983765 | Recall 0.29491549730300903 | Bal Acc 0.5663889646530151
Loss 0.79469233751297 | F1 0.3806449770927429 | AUC 0.6778729220863033 | Specificity 0.8693472146987915 | Recall 0.3819085955619812 | Bal Acc 0.625627875328064
Loss 0.8110535740852356 | F1 0.3501167893409729 | AUC 0.7095406714783695 | Specificity 0.8584957718849182 | Recall 0.33841440081596375 | Bal Acc 0.5984550714492798


(0.8486228585243225,
 0.32908029556274415,
 0.6640686895563228,
 0.8522110939025879,
 0.3241289258003235)

In [17]:
mtrainer.run_eval(mtrainer.best_model, query_loader, True)

Loss 0.8142952919006348 | F1 0.4863552153110504 | AUC 0.6676657085120417 | Specificity 0.7322176694869995 | Recall 0.5168277621269226 | Bal Acc 0.6245226860046387
Loss 0.7670930624008179 | F1 0.5071707963943481 | AUC 0.6955945824336179 | Specificity 0.7418272495269775 | Recall 0.5384732484817505 | Bal Acc 0.640150249004364
Loss 0.8211760520935059 | F1 0.4695371985435486 | AUC 0.6653073310357044 | Specificity 0.7258445024490356 | Recall 0.4860992133617401 | Bal Acc 0.6059718728065491
Loss 0.8326702117919922 | F1 0.46564146876335144 | AUC 0.6749339870345388 | Specificity 0.7001748085021973 | Recall 0.5016079545021057 | Bal Acc 0.6008913516998291
Loss 0.8312413692474365 | F1 0.4304371774196625 | AUC 0.6596898542511708 | Specificity 0.7387593984603882 | Recall 0.4461764693260193 | Bal Acc 0.5924679040908813


(0.8132951974868774,
 0.4718283712863922,
 0.6726382926534147,
 0.7277647256851196,
 0.49783692955970765)

In [21]:
torch.save(mtrainer.best_model, 'models/backbone/pretrained/vindr2/trained-backbone-ba1.pth')