In [2]:
import medmnist
from medmnist import INFO

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np

# chestmnist, retinamnist
def get_image_mean_std(dataname):
    info = INFO[dataname]
    DataClass = getattr(medmnist, info['python_class'])

    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224))
            ])

    train_dataset = DataClass(split='train', transform=transform, download=True)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=8192)
    total = info['n_samples']['train']
    mean = torch.zeros(info['n_channels'])
    std = torch.zeros(info['n_channels'])
    for images, _ in train_loader:
        num_img = len(images)
        m, s = images.mean([0,2,3]), images.std([0,2,3])
        mean += num_img * m / total
        std += np.sqrt(num_img/total) * s
    return mean, std

In [19]:
mean_stds = {}

for k in ['pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist', 'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']:
    mean, std = get_image_mean_std(k)
    print(k, mean, std)
    mean_stds[k] = {
        'mean': mean,
        'std': std
    }
print(mean_stds)

Using downloaded and verified file: /Users/naomileow/.medmnist/pathmnist.npz
pathmnist tensor([0.7405, 0.5330, 0.7058]) tensor([0.3920, 0.5636, 0.3959])
Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /Users/naomileow/.medmnist/chestmnist.npz


  0%|          | 0/82802576 [00:00<?, ?it/s]

chestmnist tensor([0.4936]) tensor([0.7392])
Downloading https://zenodo.org/record/6496656/files/dermamnist.npz?download=1 to /Users/naomileow/.medmnist/dermamnist.npz


  0%|          | 0/19725078 [00:00<?, ?it/s]

dermamnist tensor([0.7631, 0.5381, 0.5614]) tensor([0.1354, 0.1530, 0.1679])
Downloading https://zenodo.org/record/6496656/files/octmnist.npz?download=1 to /Users/naomileow/.medmnist/octmnist.npz


  0%|          | 0/54938180 [00:00<?, ?it/s]

octmnist tensor([0.1889]) tensor([0.6606])
Downloading https://zenodo.org/record/6496656/files/pneumoniamnist.npz?download=1 to /Users/naomileow/.medmnist/pneumoniamnist.npz


  0%|          | 0/4170669 [00:00<?, ?it/s]

pneumoniamnist tensor([0.5719]) tensor([0.1651])
Downloading https://zenodo.org/record/6496656/files/retinamnist.npz?download=1 to /Users/naomileow/.medmnist/retinamnist.npz


  0%|          | 0/3291041 [00:00<?, ?it/s]

retinamnist tensor([0.3984, 0.2447, 0.1558]) tensor([0.2952, 0.1970, 0.1470])
Downloading https://zenodo.org/record/6496656/files/breastmnist.npz?download=1 to /Users/naomileow/.medmnist/breastmnist.npz


  0%|          | 0/559580 [00:00<?, ?it/s]

breastmnist tensor([0.3276]) tensor([0.2027])
Downloading https://zenodo.org/record/6496656/files/bloodmnist.npz?download=1 to /Users/naomileow/.medmnist/bloodmnist.npz


  0%|          | 0/35461855 [00:00<?, ?it/s]

bloodmnist tensor([0.7943, 0.6597, 0.6962]) tensor([0.2930, 0.3292, 0.1541])
Downloading https://zenodo.org/record/6496656/files/tissuemnist.npz?download=1 to /Users/naomileow/.medmnist/tissuemnist.npz


  0%|          | 0/124962739 [00:00<?, ?it/s]

tissuemnist tensor([0.1020]) tensor([0.4443])
Downloading https://zenodo.org/record/6496656/files/organamnist.npz?download=1 to /Users/naomileow/.medmnist/organamnist.npz


  0%|          | 0/38247903 [00:00<?, ?it/s]

organamnist tensor([0.4678]) tensor([0.6105])
Downloading https://zenodo.org/record/6496656/files/organcmnist.npz?download=1 to /Users/naomileow/.medmnist/organcmnist.npz


  0%|          | 0/15527535 [00:00<?, ?it/s]

organcmnist tensor([0.4932]) tensor([0.3762])
Downloading https://zenodo.org/record/6496656/files/organsmnist.npz?download=1 to /Users/naomileow/.medmnist/organsmnist.npz


  0%|          | 0/16528536 [00:00<?, ?it/s]

organsmnist tensor([0.4950]) tensor([0.3779])


  0%|          | 0/32657407 [00:00<?, ?it/s]

In [1]:
from utils.device import get_device
from models.backbone.datasets import MEAN_STDS, DataSets
from models.backbone.trainer import Trainer

device = get_device()
MODEL_SAVE_PATH = 'models/backbone/pretrained'

In [3]:
import torchvision
import torch
from torchvision.models import ResNet50_Weights
from torch import nn

retina_ds = DataSets('retinamnist', MEAN_STDS) # 5 classes
backbone = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
backbone.fc = nn.Linear(2048, len(retina_ds.info['label']))

batch_size = 256
rtrainer = Trainer(backbone, retina_ds, batch_size, device, balance=False)

Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz


In [None]:
rtrainer.run_train(5)

In [20]:
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

(0.52, 0.7557038049792416, 1.3845270824432374)


In [None]:
import os

# (0.5125, 0.7247683647010932, 1.464751205444336) for non pretrained, non balanced
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

torch.save(rtrainer.best_model.state_dict(), os.path.join(MODEL_SAVE_PATH, 'retina_backbone_pretrained_bal.pkl'))

In [10]:
import torchvision

# 'chestmnist', 'pneumoniamnist', 'octmnist',  'retinamnist'
chest_ds = DataSets('chestmnist', MEAN_STDS)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), pretrained=False)
# patch for single channel
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

batch_size = 128
ctrainer = Trainer(backbone, chest_ds, batch_size, device)

ctrainer.run_train(30)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [10]:
import torchvision
import os
import torch

# [7996,  1950,  9261, 13914,  3988,  4375,   978,  3705,  3263, 1690,  1799,  1158,  2279,   144]
chest_ds = DataSets('chestmnist', mean_stds)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, 'cxr_backbone.pkl')))

batch_size = 256
ctrainer = Trainer(backbone, chest_ds, batch_size, device, balance=True)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [None]:
ctrainer.run_train(5)

In [14]:
ctrainer.run_eval(ctrainer.model, ctrainer.test_loader)

(0.9239799784755875, 0.6329840270919405, 0.7261211995067128)

In [None]:
# Load and save pretrained resnet from medclip
import torch
from torch import nn

from medclip import MedCLIPModel, MedCLIPVisionModel

# load MedCLIP-ResNet50
model = MedCLIPModel(vision_cls=MedCLIPVisionModel)
model.from_pretrained()

bconv_weight = model.vision_model.model.conv1.weight.mean(dim=1).unsqueeze(1)

# The resnet model was trained on CheXpert and MIMIC-CXR
backbone = model.vision_model.model
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.conv1.weight = nn.Parameter(bconv_weight)

torch.save(backbone.state_dict(), 'medclip_resnet50.pkl')

In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler

configs = {
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*10
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

PROJ_SIZE = 512
device = get_device()

In [2]:
import torchvision
import os
import torch
from models.backbone.trainer import DSTrainer
from utils.f1_loss import BalAccuracyLoss

backbone = torchvision.models.resnet50(num_classes=len(config.classes_split_map['train']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

mtrainer = DSTrainer(backbone, query_dataset.class_labels(), criterion=BalAccuracyLoss(), device=device)

In [10]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

Loss 0.8178021907806396 | F1 0.3351583778858185 | AUC 0.602855014089035 | Specificity 0.7272143363952637 | Recall 0.4031338691711426 | Bal Acc 0.5651741027832031
Loss 0.8040915727615356 | F1 0.32222121953964233 | AUC 0.5738897543349903 | Specificity 0.7176041603088379 | Recall 0.3890199661254883 | Bal Acc 0.5533120632171631
Loss 0.8150610327720642 | F1 0.32819268107414246 | AUC 0.5927963819758275 | Specificity 0.685369610786438 | Recall 0.41243231296539307 | Bal Acc 0.5489009618759155
Loss 0.8141124248504639 | F1 0.31385940313339233 | AUC 0.6090060357042646 | Specificity 0.7192613482475281 | Recall 0.3877924680709839 | Bal Acc 0.5535268783569336
Loss 0.8236879706382751 | F1 0.3152325451374054 | AUC 0.5647615005108276 | Specificity 0.7081069946289062 | Recall 0.37620580196380615 | Bal Acc 0.5421563982963562


(0.8149510383605957,
 tensor(0.3229, device='mps:0'),
 0.588661737322989,
 0.7115112900733948,
 0.3937168836593628)

In [9]:
mtrainer.run_eval(mtrainer.best_model, query_loader, True)

Loss 0.7730595469474792 | F1 0.4065409302711487 | AUC 0.6377244623793866 | Specificity 0.6742991209030151 | Recall 0.5058967471122742 | Bal Acc 0.5900979042053223
Loss 0.792277455329895 | F1 0.38331103324890137 | AUC 0.6266736220474719 | Specificity 0.6769896745681763 | Recall 0.5041444301605225 | Bal Acc 0.5905670523643494
Loss 0.8309187889099121 | F1 0.3456721305847168 | AUC 0.6039829723119572 | Specificity 0.656503438949585 | Recall 0.47546499967575073 | Bal Acc 0.5659842491149902
Loss 0.7993811368942261 | F1 0.3438308537006378 | AUC 0.5979420461276035 | Specificity 0.6408977508544922 | Recall 0.469146728515625 | Bal Acc 0.5550222396850586
Loss 0.7731005549430847 | F1 0.3598782420158386 | AUC 0.6240732889015435 | Specificity 0.6538779735565186 | Recall 0.48356109857559204 | Bal Acc 0.5687195062637329


(0.7937474966049194,
 tensor(0.3678, device='mps:0'),
 0.6180792783535926,
 0.6605135917663574,
 0.48764280080795286)

In [3]:
mtrainer.run_train(10, support_loader, query_loader, lr=1e-4, weight_decay=1e-2)

Batch 1: loss 0.8280101418495178
Batch 2: loss 0.8292435109615326
Batch 3: loss 0.8295965592066447
Batch 4: loss 0.8288757503032684
Batch 5: loss 0.8278212785720825
Batch 6: loss 0.8269070088863373
Batch 7: loss 0.8257842574800763
Batch 8: loss 0.8247404843568802
Batch 9: loss 0.8234745727645026
Batch 10: loss 0.8221038401126861
Batch 11: loss 0.8206689737059853
Batch 12: loss 0.8191321740547816
Batch 13: loss 0.81712773671517
Batch 14: loss 0.8156357535294124
Batch 15: loss 0.8143048286437988
Batch 16: loss 0.8137635849416256
Batch 17: loss 0.8123951484175289
Batch 18: loss 0.811598002910614
Batch 19: loss 0.8101978678452341
Batch 20: loss 0.8089416712522507
Batch 21: loss 0.8080876100630987
Batch 22: loss 0.8072474815628745
Batch 23: loss 0.806087763413139
Batch 24: loss 0.8052229334910711
Batch 25: loss 0.8045683622360229
Batch 26: loss 0.8032535773057204
Batch 27: loss 0.8025501524960553
Batch 28: loss 0.8007806433098656
Batch 29: loss 0.7994909224839046
Batch 30: loss 0.7983721097

  output, inverse_indices, counts = torch._unique2(
  denom[denom == 0.0] = 1


Epoch 1: Validation loss 0.7927254676818848 | F1 0.43386778235435486 | AUC 0.5992680664118906 | Acc H-Mean 0.5316094630482393
Batch 1: loss 0.6033168435096741
Batch 2: loss 0.6056607067584991
Batch 3: loss 0.6068258484204611
Batch 4: loss 0.6138665080070496
Batch 5: loss 0.6156848430633545
Batch 6: loss 0.6145555178324381
Batch 7: loss 0.6182503274508885
Batch 8: loss 0.6162194386124611
Batch 9: loss 0.6142563356293572
Batch 10: loss 0.6154501378536225
Batch 11: loss 0.6160964153029702
Batch 12: loss 0.616858571767807
Batch 13: loss 0.61766189795274
Batch 14: loss 0.6189858402524676
Batch 15: loss 0.6176013469696044
Batch 16: loss 0.6182028129696846
Batch 17: loss 0.6190365973640891
Batch 18: loss 0.6184763179885017
Batch 19: loss 0.6181828661968833
Batch 20: loss 0.616915363073349
Batch 21: loss 0.6158973035358247
Batch 22: loss 0.6156895973465659
Batch 23: loss 0.6152722550475079
Batch 24: loss 0.6155804693698883
Batch 25: loss 0.6153893566131592
Batch 26: loss 0.6165638336768517
Bat

In [6]:
torch.save(mtrainer.best_model, 'models/backbone/pretrained/vindr2/trained-backbone-balacc.pth')