!pip install -qqq transformers medmnist

In [1]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import random
import torch
from torch import nn, optim
from torch.utils import data
import torch.nn.functional as F
from torchvision import datasets, transforms, models

import medmnist
from medmnist import INFO, Evaluator
from numpy.random import RandomState
from torch.utils.data import Subset
import re

In [2]:
def train(model, device, train_loader, optimizer, epoch, display=True):
    """Run one optimisation pass over ``train_loader``.

    The model is called as ``model(inputs, labels=targets)`` and must return
    an object exposing a ``loss`` attribute. That loss is back-propagated and
    ``optimizer`` is stepped once per batch.

    Args:
        model: Module to train; it is switched to training mode.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(inputs, targets)`` batches. When
            ``display`` is true it must also expose ``.dataset`` with a
            length, which is used for the progress message.
        optimizer: Optimizer holding the parameters to update.
        epoch: Epoch index; only used in the progress message.
        display: If true, print a one-line summary once the pass is done.

    Returns:
        The loss of the last processed batch as a Python float.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    seen = 0
    loss = None
    # Named `inputs`/`targets` rather than `data`/`target`: `data` is also the
    # name under which torch.utils.data is imported in this file.
    for inputs, targets in train_loader:
        inputs = inputs.to(torch.float32).to(device)
        targets = targets.to(torch.float32).to(device)
        optimizer.zero_grad()
        output = model(inputs, labels=targets)
        loss = output.loss
        loss.backward()
        optimizer.step()
        seen += len(inputs)

    if loss is None:
        # Without this guard an empty loader surfaced as an unbound-local
        # error on `loss`, which hides the real cause.
        raise ValueError('train_loader produced no batches')

    last_loss = loss.item()
    if display:
        # The summary is printed after the loop, so report the number of
        # samples actually processed instead of the last batch's start offset.
        total = len(train_loader.dataset)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, total, 100. * seen / total, last_loss))
    return last_loss


def test(model, device, test_loader, name="\nVal", get_loss=False):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    The model is called as ``model(inputs, labels=targets)`` and must return
    an object exposing ``loss`` and ``logits``. A sample counts as correct
    when ``logits > 0.5`` equals its target.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        device: Device every batch is moved to.
        test_loader: Iterable of ``(inputs, targets)`` batches that also
            exposes ``.dataset`` with a length.
        name: Prefix of the printed summary line.
        get_loss: If true, return the average loss and print nothing.

    Returns:
        The per-sample average loss when ``get_loss`` is true, otherwise the
        accuracy in percent.

    Raises:
        ValueError: If ``test_loader.dataset`` is empty.
    """
    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        raise ValueError('test_loader.dataset is empty')

    model.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            # Cast inputs the same way train() does.
            inputs = inputs.to(torch.float32).to(device)
            targets = targets.to(torch.float32).to(device)
            output = model(inputs, labels=targets)
            # Weight each batch loss by its size so that dividing by the
            # dataset size below yields a true per-sample average. This
            # assumes `output.loss` is mean-reduced over the batch -- TODO
            # confirm for the model in use.
            loss_sum += output.loss.item() * len(inputs)
            pred = output.logits > 0.5
            correct += pred.eq(targets.view_as(pred)).sum().item()

    avg_loss = loss_sum / n_samples
    if get_loss:
        return avg_loss

    accuracy = 100. * correct / n_samples
    print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        name, avg_loss, correct, n_samples, accuracy))
    return accuracy


# The name matches pytest's default test-discovery pattern; opt out so that
# importing this helper into a test module does not make pytest try to run it
# as a test case.
test.__test__ = False

In [3]:
from transformers import ViTForImageClassification, ViTConfig, ViTFeatureExtractor

In [4]:
# Seed every random number generator used in this notebook with the same
# fixed value, and ask cuDNN for deterministic behaviour (no auto-tuning).
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Dataset selection: look up the MedMNIST metadata for the chosen flag and
# resolve the dataset class it names.
data_flag = 'pneumoniamnist'
download = True

info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
n_classes = len(info['label'])

In [5]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Depending on the installed transformers release, `feature_extractor.size`
# is either a plain int or a dict such as {"height": 224, "width": 224}.
# torchvision transforms only accept an int or a sequence, so normalise it
# once here instead of passing the raw attribute through.
_vit_size = feature_extractor.size
if isinstance(_vit_size, dict):
    if 'height' in _vit_size and 'width' in _vit_size:
        image_size = (_vit_size['height'], _vit_size['width'])
    else:
        image_size = _vit_size['shortest_edge']
else:
    image_size = _vit_size

# Normalisation is applied while the tensor still has a single channel
# (one mean/std value); the channel is then replicated three times.
normalize = transforms.Normalize(mean=[.5], std=[.5])
grayToRgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))

# Augmenting pipeline. NOTE(review): not referenced by the dataset
# construction visible in this file, which uses `data_transforms`.
train_transforms = transforms.Compose(
    [
        transforms.RandAugment(),
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
        grayToRgb,
    ]
)

# Deterministic resize + centre-crop pipeline. NOTE(review): also not
# referenced by the dataset construction visible in this file.
val_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
        grayToRgb,
    ]
)

# Plain pipeline without augmentation: resize to 224, convert to a tensor,
# normalise, and expand to three channels.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    normalize,
    grayToRgb,
])

In [6]:
# Few-shot experiment: for each seed, train only the classification head of a
# pretrained ViT on a handful of samples, then measure validation accuracy.
import os

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

n_epochs = 200
plot_dir = 'img'
# Create the output directory up front; otherwise savefig() fails only after
# a full training run has already completed.
os.makedirs(plot_dir, exist_ok=True)

# load the data
# Both datasets wrap the 'train' split. The validation subset is carved out
# of it below, using indices disjoint from the training subset.
train_dataset = DataClass(split='train', transform=data_transforms, download=download)
val_dataset = DataClass(split='train', transform=data_transforms, download=download)

accs_val = []

# Per-class sample counts (loop-invariant).
train_top = 10 // n_classes
val_top = 1000 // n_classes

for seed in range(1, 51):
    prng = RandomState(seed)
    # The permutation covers positions 0..999 and is used to index into each
    # class's own index array, so every class needs at least 1000 samples.
    random_permute = prng.permutation(np.arange(0, 1000))
    indx_train = np.concatenate([
        np.where(train_dataset.labels == label)[0][random_permute[0:train_top]]
        for label in range(0, n_classes)
    ])
    indx_val = np.concatenate([
        np.where(train_dataset.labels == label)[0][random_permute[train_top:train_top + val_top]]
        for label in range(0, n_classes)
    ])

    train_data = Subset(train_dataset, indx_train)
    val_data = Subset(val_dataset, indx_val)

    print('Num Samples For Training %d Num Samples For Val %d'%(train_data.indices.shape[0],val_data.indices.shape[0]))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=32,
                                               shuffle=True)

    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=128,
                                             shuffle=False)

    # Fresh pretrained backbone with a newly initialised single-output head.
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        num_labels=1
    )
    # Only the classifier is handed to the optimizer, so freeze the backbone:
    # this skips its backward pass and stops gradients that are never zeroed
    # from piling up on its parameters. The updates applied are unchanged.
    model.vit.requires_grad_(False)
    model.to(device).train()

    optimizer = optim.SGD(model.classifier.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    losses_train = []
    for epoch in range(n_epochs):
        train_loss = train(model, device, train_loader, optimizer, epoch, display=epoch%5==0)
        losses_train.append(train_loss)
        # The plateau scheduler is fed the validation loss every 10th epoch.
        if epoch%10 == 0:
            scheduler.step(test(model, device, val_loader, get_loss=True))

    accs_val.append(test(model, device, val_loader))

    # One figure per seed, closed after saving. Plotting into the implicit
    # global figure made every saved image accumulate the curves of all
    # earlier seeds and kept the figures alive.
    fig, ax = plt.subplots()
    ax.plot(losses_train, label=seed)
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    # NOTE(review): train() optimises the model's own `output.loss`; the
    # explicit binary_cross_entropy_with_logits call there is commented out,
    # so the loss named in this title may be stale -- confirm.
    ax.set_title(f'Training loss over {n_epochs} epochs for ViT with binary_cross_entropy_with_logits')
    ax.legend()
    fig.savefig(os.path.join(plot_dir, f'vit_converge_seed{seed}.png'))
    plt.close(fig)

accs_val = np.array(accs_val)

print('Val acc over %d instances on dataset: %s %.2f +- %.2f'%(len(accs_val), data_flag, accs_val.mean(), accs_val.std()))

Using downloaded and verified file: /home/z_yuxian/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /home/z_yuxian/.medmnist/pneumoniamnist.npz
Num Samples For Training 10 Num Samples For Val 1000


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.




KeyboardInterrupt: 