In [2]:
import medmnist
from medmnist import INFO

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np

# chestmnist, retinamnist
def get_image_mean_std(dataname):
    info = INFO[dataname]
    DataClass = getattr(medmnist, info['python_class'])

    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224))
            ])

    train_dataset = DataClass(split='train', transform=transform, download=True)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=8192)
    total = info['n_samples']['train']
    mean = torch.zeros(info['n_channels'])
    std = torch.zeros(info['n_channels'])
    for images, _ in train_loader:
        num_img = len(images)
        m, s = images.mean([0,2,3]), images.std([0,2,3])
        mean += num_img * m / total
        std += np.sqrt(num_img/total) * s
    return mean, std

In [19]:
mean_stds = {}

for k in ['pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist', 'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']:
    mean, std = get_image_mean_std(k)
    print(k, mean, std)
    mean_stds[k] = {
        'mean': mean,
        'std': std
    }
print(mean_stds)

Using downloaded and verified file: /Users/naomileow/.medmnist/pathmnist.npz
pathmnist tensor([0.7405, 0.5330, 0.7058]) tensor([0.3920, 0.5636, 0.3959])
Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /Users/naomileow/.medmnist/chestmnist.npz


  0%|          | 0/82802576 [00:00<?, ?it/s]

chestmnist tensor([0.4936]) tensor([0.7392])
Downloading https://zenodo.org/record/6496656/files/dermamnist.npz?download=1 to /Users/naomileow/.medmnist/dermamnist.npz


  0%|          | 0/19725078 [00:00<?, ?it/s]

dermamnist tensor([0.7631, 0.5381, 0.5614]) tensor([0.1354, 0.1530, 0.1679])
Downloading https://zenodo.org/record/6496656/files/octmnist.npz?download=1 to /Users/naomileow/.medmnist/octmnist.npz


  0%|          | 0/54938180 [00:00<?, ?it/s]

octmnist tensor([0.1889]) tensor([0.6606])
Downloading https://zenodo.org/record/6496656/files/pneumoniamnist.npz?download=1 to /Users/naomileow/.medmnist/pneumoniamnist.npz


  0%|          | 0/4170669 [00:00<?, ?it/s]

pneumoniamnist tensor([0.5719]) tensor([0.1651])
Downloading https://zenodo.org/record/6496656/files/retinamnist.npz?download=1 to /Users/naomileow/.medmnist/retinamnist.npz


  0%|          | 0/3291041 [00:00<?, ?it/s]

retinamnist tensor([0.3984, 0.2447, 0.1558]) tensor([0.2952, 0.1970, 0.1470])
Downloading https://zenodo.org/record/6496656/files/breastmnist.npz?download=1 to /Users/naomileow/.medmnist/breastmnist.npz


  0%|          | 0/559580 [00:00<?, ?it/s]

breastmnist tensor([0.3276]) tensor([0.2027])
Downloading https://zenodo.org/record/6496656/files/bloodmnist.npz?download=1 to /Users/naomileow/.medmnist/bloodmnist.npz


  0%|          | 0/35461855 [00:00<?, ?it/s]

bloodmnist tensor([0.7943, 0.6597, 0.6962]) tensor([0.2930, 0.3292, 0.1541])
Downloading https://zenodo.org/record/6496656/files/tissuemnist.npz?download=1 to /Users/naomileow/.medmnist/tissuemnist.npz


  0%|          | 0/124962739 [00:00<?, ?it/s]

tissuemnist tensor([0.1020]) tensor([0.4443])
Downloading https://zenodo.org/record/6496656/files/organamnist.npz?download=1 to /Users/naomileow/.medmnist/organamnist.npz


  0%|          | 0/38247903 [00:00<?, ?it/s]

organamnist tensor([0.4678]) tensor([0.6105])
Downloading https://zenodo.org/record/6496656/files/organcmnist.npz?download=1 to /Users/naomileow/.medmnist/organcmnist.npz


  0%|          | 0/15527535 [00:00<?, ?it/s]

organcmnist tensor([0.4932]) tensor([0.3762])
Downloading https://zenodo.org/record/6496656/files/organsmnist.npz?download=1 to /Users/naomileow/.medmnist/organsmnist.npz


  0%|          | 0/16528536 [00:00<?, ?it/s]

organsmnist tensor([0.4950]) tensor([0.3779])


  0%|          | 0/32657407 [00:00<?, ?it/s]

In [1]:
from utils.device import get_device
from models.backbone.datasets import MEAN_STDS, DataSets
from models.backbone.trainer import Trainer

device = get_device()
MODEL_SAVE_PATH = 'models/backbone/pretrained'

In [3]:
import torchvision
import torch
from torchvision.models import ResNet50_Weights
from torch import nn

retina_ds = DataSets('retinamnist', MEAN_STDS) # 5 classes
backbone = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
backbone.fc = nn.Linear(2048, len(retina_ds.info['label']))

batch_size = 256
rtrainer = Trainer(backbone, retina_ds, batch_size, device, balance=False)

Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz


In [None]:
rtrainer.run_train(5)

In [20]:
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

(0.52, 0.7557038049792416, 1.3845270824432374)


In [None]:
import os

# (0.5125, 0.7247683647010932, 1.464751205444336) for non pretrained, non balanced
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

torch.save(rtrainer.best_model.state_dict(), os.path.join(MODEL_SAVE_PATH, 'retina_backbone_pretrained_bal.pkl'))

In [10]:
import torchvision

# 'chestmnist', 'pneumoniamnist', 'octmnist',  'retinamnist'
chest_ds = DataSets('chestmnist', MEAN_STDS)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), pretrained=False)
# patch for single channel
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

batch_size = 128
ctrainer = Trainer(backbone, chest_ds, batch_size, device)

ctrainer.run_train(30)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [10]:
import torchvision
import os
import torch

# [7996,  1950,  9261, 13914,  3988,  4375,   978,  3705,  3263, 1690,  1799,  1158,  2279,   144]
chest_ds = DataSets('chestmnist', mean_stds)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, 'cxr_backbone.pkl')))

batch_size = 256
ctrainer = Trainer(backbone, chest_ds, batch_size, device, balance=True)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [None]:
ctrainer.run_train(5)

In [14]:
ctrainer.run_eval(ctrainer.model, ctrainer.test_loader)

(0.9239799784755875, 0.6329840270919405, 0.7261211995067128)

In [None]:
# Load and save pretrained resnet from medclip
import torch
from torch import nn

from medclip import MedCLIPModel, MedCLIPVisionModel

# load MedCLIP-ResNet50
model = MedCLIPModel(vision_cls=MedCLIPVisionModel)
model.from_pretrained()

bconv_weight = model.vision_model.model.conv1.weight.mean(dim=1).unsqueeze(1)

# The resnet model was trained on CheXpert and MIMIC-CXR
backbone = model.vision_model.model
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.conv1.weight = nn.Parameter(bconv_weight)

torch.save(backbone.state_dict(), 'medclip_resnet50.pkl')

In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler

configs = {
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*10
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

PROJ_SIZE = 512
device = get_device()

In [3]:
import torchvision
import os
import torch
from torch import nn
from models.backbone.trainer import DSTrainer
from utils.f1_loss import BalAccuracyLoss, MCCLoss, F1Loss
from models.embedding.model import load_medclip_retrained_resnet, load_pretrained_resnet

# backbone = torchvision.models.resnet50(num_classes=len(config.classes_split_map['train']), weights=None)
# backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone = load_pretrained_resnet(1, 512, 'models/backbone/pretrained/medclip_resnet50.pkl', False)
backbone.fc = nn.Linear(2048, len(config.classes_split_map['train']), bias=False)

mtrainer = DSTrainer(backbone, query_dataset.class_labels(), criterion=BalAccuracyLoss(), device=device)

In [4]:
mtrainer.run_train(9, support_loader, query_loader, lr=1e-4, weight_decay=1e-2)

Batch 1: loss 0.5119678974151611
Batch 2: loss 0.5130165815353394
Batch 3: loss 0.5083330869674683
Batch 4: loss 0.5047162771224976
Batch 5: loss 0.5002678632736206
Batch 6: loss 0.4915280342102051
Batch 7: loss 0.48292696475982666
Batch 8: loss 0.4732884168624878
Batch 9: loss 0.4661235809326172
Batch 10: loss 0.4513899087905884
Batch 11: loss 0.4456549286842346
Batch 12: loss 0.4385819435119629
Batch 13: loss 0.40876710414886475
Batch 14: loss 0.41575318574905396
Batch 15: loss 0.3964570164680481
Batch 16: loss 0.3955634832382202
Batch 17: loss 0.38344401121139526
Batch 18: loss 0.3853582739830017
Batch 19: loss 0.37931305170059204
Batch 20: loss 0.3641212582588196
Batch 21: loss 0.37644582986831665
Batch 22: loss 0.36176663637161255
Batch 23: loss 0.3737676739692688
Batch 24: loss 0.3661550283432007
Batch 25: loss 0.3574559688568115
Batch 26: loss 0.3555760383605957
Batch 27: loss 0.35296630859375
Batch 28: loss 0.3469029664993286
Batch 29: loss 0.3210749626159668
Batch 30: loss 0.3

  output, inverse_indices, counts = torch._unique2(
  denom[denom == 0.0] = 1


Epoch 1: Validation loss 0.34194380044937134 | F1 0.5965440750122071 | AUC 0.7955255076626198 | Acc H-Mean 0.6950270412651609
Batch 1: loss 0.06137770414352417
Batch 2: loss 0.07616311311721802
Batch 3: loss 0.050753891468048096
Batch 4: loss 0.059242069721221924
Batch 5: loss 0.05802994966506958
Batch 6: loss 0.06616359949111938
Batch 7: loss 0.07716786861419678
Batch 8: loss 0.05368441343307495
Batch 9: loss 0.05364042520523071
Batch 10: loss 0.060742199420928955
Batch 11: loss 0.06535154581069946
Batch 12: loss 0.05269354581832886
Batch 13: loss 0.04498434066772461
Batch 14: loss 0.074970543384552
Batch 15: loss 0.062256693840026855
Batch 16: loss 0.05268007516860962
Batch 17: loss 0.06480079889297485
Batch 18: loss 0.06918841600418091
Batch 19: loss 0.055506765842437744
Batch 20: loss 0.056970953941345215
Batch 21: loss 0.04814988374710083
Batch 22: loss 0.05216604471206665
Batch 23: loss 0.04417693614959717
Batch 24: loss 0.05861949920654297
Batch 25: loss 0.05966430902481079
Batc

In [15]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

Loss 0.4831690788269043 | F1 0.46632564067840576 | AUC 0.77686636362777 | Specificity 0.8896466493606567 | Recall 0.4324343800544739 | Bal Acc 0.6610405445098877 | Acc 0.7259999513626099
Loss 0.47481298446655273 | F1 0.4678463339805603 | AUC 0.7806188292567595 | Specificity 0.8718984723091125 | Recall 0.45877307653427124 | Bal Acc 0.6653357744216919 | Acc 0.7379999756813049
Loss 0.4411987066268921 | F1 0.49595433473587036 | AUC 0.7779539128360795 | Specificity 0.8714594841003418 | Recall 0.4803504943847656 | Bal Acc 0.6759049892425537 | Acc 0.7439999580383301
Loss 0.47626906633377075 | F1 0.47104692459106445 | AUC 0.7963559349410614 | Specificity 0.8575273752212524 | Recall 0.47236400842666626 | Bal Acc 0.6649457216262817 | Acc 0.7269999980926514
Loss 0.4233925938606262 | F1 0.5393682718276978 | AUC 0.8172842148198833 | Specificity 0.8938779830932617 | Recall 0.5020822882652283 | Bal Acc 0.6979801654815674 | Acc 0.7700000405311584


(0.45976848602294923,
 0.48810830116271975,
 0.7898158510963108,
 0.876881992816925,
 0.46920084953308105)

In [14]:
mtrainer.run_eval(mtrainer.best_model, query_loader, True)

Loss 0.3236452341079712 | F1 0.6388580203056335 | AUC 0.8052553663786899 | Specificity 0.8051336407661438 | Recall 0.6420743465423584 | Bal Acc 0.7236039638519287 | Acc 0.7380000352859497
Loss 0.35609865188598633 | F1 0.5631943941116333 | AUC 0.7861629817423941 | Specificity 0.8200464248657227 | Recall 0.5798979997634888 | Bal Acc 0.6999722123146057 | Acc 0.7230000495910645
Loss 0.3331189751625061 | F1 0.6214109659194946 | AUC 0.819008560788822 | Specificity 0.8354839086532593 | Recall 0.6005048155784607 | Bal Acc 0.7179943323135376 | Acc 0.7569999694824219
Loss 0.347114622592926 | F1 0.5806185603141785 | AUC 0.7836986628788286 | Specificity 0.8207861185073853 | Recall 0.5990288257598877 | Bal Acc 0.7099074721336365 | Acc 0.7280000448226929
Loss 0.3335912227630615 | F1 0.5932779312133789 | AUC 0.7935558433293111 | Specificity 0.8327345848083496 | Recall 0.6035705208778381 | Bal Acc 0.7181525230407715 | Acc 0.7409999966621399


(0.3387137413024902,
 0.5994719743728638,
 0.797536283023609,
 0.8228369355201721,
 0.6050153017044068)

In [8]:
torch.save(mtrainer.best_model, 'models/backbone/pretrained/vindr2/trained-backbone-ba.pth')

In [None]:
import torchvision
import os
import torch
from torch import nn
from models.backbone.trainer import DSTrainer
from utils.f1_loss import BalAccuracyLoss, MCCLoss, F1Loss
from models.embedding.model import load_medclip_retrained_resnet, load_pretrained_resnet

backbone = torch.load('models/backbone/pretrained/vindr2/trained-backbone-ba1.pth')

mtrainer = DSTrainer(backbone, query_dataset.class_labels(), criterion=BalAccuracyLoss(), device=device)

In [None]:
mtrainer.run_train(7, support_loader, query_loader, lr=1e-4, weight_decay=1e-2)

In [None]:
torch.save(mtrainer.best_model, 'models/backbone/pretrained/vindr2/trained-backbone-ba1.pth')

In [2]:
from torch import nn

from models.backbone.trainer import DSTrainer
from utils.f1_loss import BalAccuracyLoss

backbone = torch.load('models/backbone/pretrained/vindr2/pretrained-ft.pth')

mtrainer = DSTrainer(backbone, query_dataset.class_labels(), criterion=BalAccuracyLoss(), device=device)

In [None]:
# Total 13 epochs
mtrainer.run_train(3, support_loader, query_loader, lr=1e-4, min_lr=1e-5, weight_decay=1e-2)

In [6]:
torch.save(mtrainer.best_model, 'models/backbone/pretrained/vindr2/pretrained-ft.pth')

In [3]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

  output, inverse_indices, counts = torch._unique2(
  denom[denom == 0.0] = 1


Loss 0.3752140998840332 | F1 0.5579749345779419 | AUC 0.6818722377061905 | Specificity 0.6007749438285828 | Recall 0.6888378858566284 | Bal Acc 0.6448063850402832 | Raw Acc 0.6240000128746033
Loss 0.3561440706253052 | F1 0.5730926394462585 | AUC 0.7209999003372558 | Specificity 0.6452262997627258 | Recall 0.6919309496879578 | Bal Acc 0.6685786247253418 | Raw Acc 0.6490000486373901
Loss 0.3474789261817932 | F1 0.5531376600265503 | AUC 0.7310684379819131 | Specificity 0.6689056754112244 | Recall 0.6753683686256409 | Bal Acc 0.6721370220184326 | Raw Acc 0.652999997138977
Loss 0.3710537552833557 | F1 0.5372604131698608 | AUC 0.6892123562738988 | Specificity 0.6180906295776367 | Recall 0.6586120128631592 | Bal Acc 0.638351321220398 | Raw Acc 0.6230000257492065
Loss 0.34730011224746704 | F1 0.5868090391159058 | AUC 0.7284682124691239 | Specificity 0.6394455432891846 | Recall 0.7057230472564697 | Bal Acc 0.6725842952728271 | Raw Acc 0.6539999842643738


(0.3594381928443909,
 0.5616549372673034,
 0.7103242289536764,
 0.6344886183738708,
 0.6840944528579712)

In [None]:
mtrainer.run_eval(mtrainer.best_model, query_loader, True)

# Loss 0.3633430004119873 | F1 0.5735126733779907 | AUC 0.7259596714303461 | Specificity 0.6672283411026001 | Recall 0.6480233073234558 | Bal Acc 0.6576257944107056 | Raw Acc 0.6639999747276306
# Loss 0.3383271098136902 | F1 0.5868979096412659 | AUC 0.7238842103494072 | Specificity 0.6317911148071289 | Recall 0.7322815656661987 | Bal Acc 0.6820363402366638 | Raw Acc 0.6549999713897705
# Loss 0.36487895250320435 | F1 0.5721150636672974 | AUC 0.7073469603005058 | Specificity 0.6365436315536499 | Recall 0.6813758611679077 | Bal Acc 0.6589597463607788 | Raw Acc 0.6499999761581421
# Loss 0.3419176936149597 | F1 0.5767743587493896 | AUC 0.725889286229896 | Specificity 0.6792066097259521 | Recall 0.6816607713699341 | Bal Acc 0.6804336905479431 | Raw Acc 0.6769999861717224
# (0.35808582305908204,
#  0.5641344547271728,
#  0.7135804541562424,
#  0.6551853537559509,
#  0.6699150681495667)

Loss 0.3633430004119873 | F1 0.5735126733779907 | AUC 0.7259596714303461 | Specificity 0.6672283411026001 | Recall 0.6480233073234558 | Bal Acc 0.6576257944107056 | Raw Acc 0.6639999747276306
Loss 0.3383271098136902 | F1 0.5868979096412659 | AUC 0.7238842103494072 | Specificity 0.6317911148071289 | Recall 0.7322815656661987 | Bal Acc 0.6820363402366638 | Raw Acc 0.6549999713897705
Loss 0.36487895250320435 | F1 0.5721150636672974 | AUC 0.7073469603005058 | Specificity 0.6365436315536499 | Recall 0.6813758611679077 | Bal Acc 0.6589597463607788 | Raw Acc 0.6499999761581421
Loss 0.3419176936149597 | F1 0.5767743587493896 | AUC 0.725889286229896 | Specificity 0.6792066097259521 | Recall 0.6816607713699341 | Bal Acc 0.6804336905479431 | Raw Acc 0.6769999861717224
Loss 0.3819623589515686 | F1 0.5113722681999207 | AUC 0.684822142471056 | Specificity 0.6611570715904236 | Recall 0.6062338352203369 | Bal Acc 0.6336954832077026 | Raw Acc 0.6240000128746033


(0.35808582305908204,
 0.5641344547271728,
 0.7135804541562424,
 0.6551853537559509,
 0.6699150681495667)