In [2]:
import medmnist
from medmnist import INFO

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import numpy as np

# chestmnist, retinamnist
def get_image_mean_std(dataname):
    info = INFO[dataname]
    DataClass = getattr(medmnist, info['python_class'])

    transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224))
            ])

    train_dataset = DataClass(split='train', transform=transform, download=True)

    train_loader = data.DataLoader(dataset=train_dataset, batch_size=8192)
    total = info['n_samples']['train']
    mean = torch.zeros(info['n_channels'])
    std = torch.zeros(info['n_channels'])
    for images, _ in train_loader:
        num_img = len(images)
        m, s = images.mean([0,2,3]), images.std([0,2,3])
        mean += num_img * m / total
        std += np.sqrt(num_img/total) * s
    return mean, std

In [19]:
mean_stds = {}

for k in ['pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist', 'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist']:
    mean, std = get_image_mean_std(k)
    print(k, mean, std)
    mean_stds[k] = {
        'mean': mean,
        'std': std
    }
print(mean_stds)

Using downloaded and verified file: /Users/naomileow/.medmnist/pathmnist.npz
pathmnist tensor([0.7405, 0.5330, 0.7058]) tensor([0.3920, 0.5636, 0.3959])
Downloading https://zenodo.org/record/6496656/files/chestmnist.npz?download=1 to /Users/naomileow/.medmnist/chestmnist.npz


  0%|          | 0/82802576 [00:00<?, ?it/s]

chestmnist tensor([0.4936]) tensor([0.7392])
Downloading https://zenodo.org/record/6496656/files/dermamnist.npz?download=1 to /Users/naomileow/.medmnist/dermamnist.npz


  0%|          | 0/19725078 [00:00<?, ?it/s]

dermamnist tensor([0.7631, 0.5381, 0.5614]) tensor([0.1354, 0.1530, 0.1679])
Downloading https://zenodo.org/record/6496656/files/octmnist.npz?download=1 to /Users/naomileow/.medmnist/octmnist.npz


  0%|          | 0/54938180 [00:00<?, ?it/s]

octmnist tensor([0.1889]) tensor([0.6606])
Downloading https://zenodo.org/record/6496656/files/pneumoniamnist.npz?download=1 to /Users/naomileow/.medmnist/pneumoniamnist.npz


  0%|          | 0/4170669 [00:00<?, ?it/s]

pneumoniamnist tensor([0.5719]) tensor([0.1651])
Downloading https://zenodo.org/record/6496656/files/retinamnist.npz?download=1 to /Users/naomileow/.medmnist/retinamnist.npz


  0%|          | 0/3291041 [00:00<?, ?it/s]

retinamnist tensor([0.3984, 0.2447, 0.1558]) tensor([0.2952, 0.1970, 0.1470])
Downloading https://zenodo.org/record/6496656/files/breastmnist.npz?download=1 to /Users/naomileow/.medmnist/breastmnist.npz


  0%|          | 0/559580 [00:00<?, ?it/s]

breastmnist tensor([0.3276]) tensor([0.2027])
Downloading https://zenodo.org/record/6496656/files/bloodmnist.npz?download=1 to /Users/naomileow/.medmnist/bloodmnist.npz


  0%|          | 0/35461855 [00:00<?, ?it/s]

bloodmnist tensor([0.7943, 0.6597, 0.6962]) tensor([0.2930, 0.3292, 0.1541])
Downloading https://zenodo.org/record/6496656/files/tissuemnist.npz?download=1 to /Users/naomileow/.medmnist/tissuemnist.npz


  0%|          | 0/124962739 [00:00<?, ?it/s]

tissuemnist tensor([0.1020]) tensor([0.4443])
Downloading https://zenodo.org/record/6496656/files/organamnist.npz?download=1 to /Users/naomileow/.medmnist/organamnist.npz


  0%|          | 0/38247903 [00:00<?, ?it/s]

organamnist tensor([0.4678]) tensor([0.6105])
Downloading https://zenodo.org/record/6496656/files/organcmnist.npz?download=1 to /Users/naomileow/.medmnist/organcmnist.npz


  0%|          | 0/15527535 [00:00<?, ?it/s]

organcmnist tensor([0.4932]) tensor([0.3762])
Downloading https://zenodo.org/record/6496656/files/organsmnist.npz?download=1 to /Users/naomileow/.medmnist/organsmnist.npz


  0%|          | 0/16528536 [00:00<?, ?it/s]

organsmnist tensor([0.4950]) tensor([0.3779])


  0%|          | 0/32657407 [00:00<?, ?it/s]

In [1]:
from utils.device import get_device
from models.backbone.datasets import MEAN_STDS, DataSets
from models.backbone.trainer import Trainer

device = get_device()
MODEL_SAVE_PATH = 'models/backbone/pretrained'

In [3]:
import torchvision
import torch
from torchvision.models import ResNet50_Weights
from torch import nn

retina_ds = DataSets('retinamnist', MEAN_STDS) # 5 classes
backbone = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
backbone.fc = nn.Linear(2048, len(retina_ds.info['label']))

batch_size = 256
rtrainer = Trainer(backbone, retina_ds, batch_size, device, balance=False)

Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/retinamnist.npz


In [None]:
rtrainer.run_train(5)

In [20]:
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

(0.52, 0.7557038049792416, 1.3845270824432374)


In [None]:
import os

# (0.5125, 0.7247683647010932, 1.464751205444336) for non pretrained, non balanced
print(rtrainer.run_eval(rtrainer.best_model, rtrainer.test_loader)) 

torch.save(rtrainer.best_model.state_dict(), os.path.join(MODEL_SAVE_PATH, 'retina_backbone_pretrained_bal.pkl'))

In [10]:
import torchvision

# 'chestmnist', 'pneumoniamnist', 'octmnist',  'retinamnist'
chest_ds = DataSets('chestmnist', MEAN_STDS)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), pretrained=False)
# patch for single channel
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

batch_size = 128
ctrainer = Trainer(backbone, chest_ds, batch_size, device)

ctrainer.run_train(30)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [10]:
import torchvision
import os
import torch

# [7996,  1950,  9261, 13914,  3988,  4375,   978,  3705,  3263, 1690,  1799,  1158,  2279,   144]
chest_ds = DataSets('chestmnist', mean_stds)

backbone = torchvision.models.resnet50(num_classes=len(chest_ds.info['label']), weights=None)
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, 'cxr_backbone.pkl')))

batch_size = 256
ctrainer = Trainer(backbone, chest_ds, batch_size, device, balance=True)

Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz
Using downloaded and verified file: /Users/naomileow/.medmnist/chestmnist.npz




In [None]:
ctrainer.run_train(5)

In [14]:
ctrainer.run_eval(ctrainer.model, ctrainer.test_loader)

(0.9239799784755875, 0.6329840270919405, 0.7261211995067128)

In [None]:
# Load and save pretrained resnet from medclip
import torch
from torch import nn

from medclip import MedCLIPModel, MedCLIPVisionModel

# load MedCLIP-ResNet50
model = MedCLIPModel(vision_cls=MedCLIPVisionModel)
model.from_pretrained()

bconv_weight = model.vision_model.model.conv1.weight.mean(dim=1).unsqueeze(1)

backbone = model.vision_model.model
backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
backbone.conv1.weight = nn.Parameter(bconv_weight)

torch.save(backbone.state_dict(), 'medclip_resnet50.pkl')


In [1]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
backbone = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
inputs = tokenizer(['Diabetic Retinopathy', 'Cataract'], return_tensors='pt', padding=True)
# t = backbone(**tokenizer('Diabetic Retinopathy'))
t = backbone(**inputs)

In [5]:
t

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[ 0.2592,  0.2502, -0.2490,  ...,  0.1130,  0.3787,  0.3324],
         [-0.0311,  0.3543, -0.0690,  ...,  0.9179,  0.0060,  0.3633],
         [-0.1416,  0.3260,  0.1543,  ...,  0.3879, -0.3474,  0.3119],
         ...,
         [-0.1302,  0.2944,  0.0212,  ...,  0.0955,  0.1206,  0.0150],
         [-0.0334, -0.0628,  0.1859,  ...,  0.4934,  0.3345,  0.3496],
         [ 1.2358,  0.3689, -0.7166,  ...,  0.0276,  0.1414,  0.4512]],

        [[ 0.2802,  0.2866, -0.4953,  ...,  0.0248,  0.3556,  0.3896],
         [ 0.2714,  0.4551, -0.2985,  ...,  0.9022, -0.1582,  0.4126],
         [-0.2106,  0.1317, -0.2367,  ...,  0.0690, -0.6939,  0.3129],
         ...,
         [-0.0426,  0.0199, -0.1960,  ...,  0.3997, -0.2091,  0.4242],
         [-0.0935,  0.0331, -0.0046,  ...,  0.4687, -0.3836,  0.0577],
         [-0.1167, -0.1394, -0.1648,  ...,  0.3966, -0.1343,  0.3803]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_ou