In [None]:
!pip install torch transformers sklearn

In [None]:
# Installs transformers straight from the default branch of its GitHub repository,
# on top of whatever release the previous cell installed.
# NOTE(review): the install is unpinned, so the API this notebook runs against
# depends on the day it is executed — consider pinning a tag or commit.
!pip install git+https://github.com/huggingface/transformers

In [1]:
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
import transformers
from PIL import Image
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Sequential, Linear, Sigmoid
from sklearn.metrics import accuracy_score

In [2]:
# Image preprocessor shared by every dataset item: XrayDataset.__getitem__ calls it
# with a PIL image and return_tensors="pt" and reads the 'pixel_values' entry of the result.
# NOTE(review): its settings come from the '-in21k' checkpoint while BinaryViT loads
# weights from 'google/vit-base-patch16-224' — presumably both checkpoints use the
# same preprocessing; confirm.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

In [3]:
# Class names; XrayDataset uses each name as a sub-folder of a split directory.
classes = ['NORMAL', 'PNEUMONIA']
# Class name -> integer label, given by the name's position in `classes`.
class_mappings = {name: index for index, name in enumerate(classes)}
# Integer label -> per-sample loss weight used by MyTrainer.compute_loss.
# NOTE(review): presumably derived from the class frequencies of the training split — confirm.
class_weights = dict(zip(range(len(classes)), (1.94, 0.67)))

In [4]:
class XrayDataset(torch.utils.data.Dataset):
    """Map-style dataset over a directory with one sub-folder per class.

    ``folder_path`` must contain one sub-directory for every name in the
    module-level ``classes`` list; every non-hidden entry of such a
    sub-directory is treated as one image whose integer label is
    ``class_mappings[<class name>]``.

    Attributes set by ``__init__``:
        folder_path: the root directory that was scanned.
        images: list of image file paths.
        labels: list of integer labels, parallel to ``images``.
    """

    def __init__(self, folder_path):
        """Scan ``folder_path`` and record every image path with its label.

        Raises:
            FileNotFoundError: if a class sub-directory is missing
                (propagated from ``os.listdir``).
        """
        self.folder_path = folder_path

        self.images = []
        self.labels = []

        for c in classes:
            class_path = os.path.join(folder_path, c)

            # os.listdir returns entries in arbitrary order; sorting makes the
            # initial sample order reproducible across machines.
            for file_name in sorted(os.listdir(class_path)):
                # Hidden entries (e.g. `.DS_Store`, `.ipynb_checkpoints`) are
                # not images and would make Image.open fail in __getitem__.
                if file_name.startswith('.'):
                    continue

                self.images.append(os.path.join(class_path, file_name))
                self.labels.append(class_mappings[c])

    def __getitem__(self, idx):
        """Load, preprocess and label the sample at position ``idx``.

        Returns:
            The object produced by ``feature_extractor`` (it carries a
            'pixel_values' entry) with an extra 'label' entry holding a
            one-element float tensor.
            NOTE(review): ``data_collator`` joins 'pixel_values' with
            ``torch.cat``, so each one presumably has a leading batch axis of
            size 1 — confirm against the feature extractor.
        """
        image_path = self.images[idx]

        # The context manager releases the underlying file handle as soon as
        # the pixel data has been copied out by convert()/resize().
        with Image.open(image_path) as raw_image:
            # NOTE(review): the feature extractor presumably resizes again to
            # the model's input size, which would make this intermediate
            # (width=640, height=720) resize redundant. It is kept so the
            # pixel values stay identical to earlier runs — confirm before
            # removing.
            image = raw_image.convert('RGB').resize((640, 720))

        item = feature_extractor(images=image, return_tensors="pt")
        item['label'] = torch.Tensor([self.labels[idx]])

        return item

    def __len__(self):
        """Number of samples found when the directory was scanned."""
        return len(self.labels)

    def shuffle(self):
        """Randomly reorder the samples in place, keeping images and labels paired.

        Uses NumPy's global random state, so the order is only reproducible
        if that state was seeded beforehand.
        """
        order = np.random.permutation(len(self))

        self.images = [self.images[i] for i in order]
        self.labels = [self.labels[i] for i in order]
    
# One dataset per split directory of the chest X-ray data.
train_ds = XrayDataset('chest_xray/train')
test_ds = XrayDataset('chest_xray/test')

# Randomise the sample order of each split in place (train first, then test).
for split_ds in (train_ds, test_ds):
    split_ds.shuffle()

In [5]:
class BinaryViT(torch.nn.Module):
    """Pre-trained ViT image classifier re-headed for binary classification.

    The transformer backbone (``pre_trained_model.vit``) is frozen; only the
    replacement head — dropout followed by a linear layer with a single
    output — has trainable parameters.

    Args:
        hidden_dropout_prob: dropout probability applied in front of the new
            linear head (also written to the loaded model's config).
        model_name: checkpoint passed to
            ``ViTForImageClassification.from_pretrained``. Defaults to the
            checkpoint this notebook has always used.
    """

    def __init__(self, hidden_dropout_prob, model_name='google/vit-base-patch16-224'):
        super().__init__()

        self.pre_trained_model = ViTForImageClassification.from_pretrained(model_name)
        # NOTE(review): the backbone's modules already exist at this point, so
        # this assignment presumably only updates the stored config and does
        # not change any dropout layer inside the backbone — confirm. The
        # dropout that is certainly applied is the one in the head below.
        self.pre_trained_model.config.hidden_dropout_prob = hidden_dropout_prob

        # Freeze the backbone so the optimizer only updates the new head.
        for param in self.pre_trained_model.vit.parameters():
            param.requires_grad = False

        # Take the head's input width from the config instead of hard-coding
        # 768, so checkpoints with a different hidden size work as well.
        hidden_size = self.pre_trained_model.config.hidden_size
        self.pre_trained_model.classifier = Sequential(
            torch.nn.Dropout(p=hidden_dropout_prob),
            Linear(in_features=hidden_size, out_features=1, bias=True)
        )

    def forward(self, pixel_values, labels=None):
        """Run the wrapped model on ``pixel_values`` and return its output.

        ``labels`` is accepted but not used here: nothing is forwarded to the
        wrapped model besides ``pixel_values``, so no loss is computed by it
        (MyTrainer.compute_loss computes the loss from the ``logits``
        attribute of the returned output).
        """
        return self.pre_trained_model(pixel_values)
    
# Build the classifier with a dropout probability of 0.35 in front of its head and
# move it to the GPU (`.cuda()` requires a CUDA-capable device to be available).
model = BinaryViT(0.35).cuda()

In [6]:
def data_collator(features):
    """Merge a sequence of dataset items into one batch dictionary.

    Args:
        features: iterable of mappings, each with a 'label' tensor and a
            'pixel_values' tensor.

    Returns:
        dict with two entries, in this order:
            'labels': the 'label' tensors stacked along a new first axis.
            'pixel_values': the 'pixel_values' tensors concatenated along
                their existing first axis.
    """
    # Single pass over `features`, reading 'label' before 'pixel_values'.
    pairs = [(sample['label'], sample['pixel_values']) for sample in features]

    label_tensors = [label for label, _ in pairs]
    image_tensors = [pixels for _, pixels in pairs]

    return {
        'labels': torch.stack(label_tensors),
        'pixel_values': torch.cat(image_tensors),
    }


def compute_metrics(eval_pred):
    """Compute accuracy from an evaluation result.

    Args:
        eval_pred: a pair that unpacks into (model scores, reference labels).
            The scores are rounded to the nearest integer to obtain the
            predicted class; MyTrainer.compute_loss returns sigmoid outputs
            under its 'logits' key, so these are presumably probabilities in
            [0, 1] — confirm how the Trainer forwards them.

    Returns:
        dict with a single 'accuracy' entry.
    """
    scores, references = eval_pred
    predictions = np.round(scores, 0)

    return dict(accuracy=accuracy_score(y_true=references, y_pred=predictions))


class MyTrainer(Trainer):
    """Trainer that optimises a class-weighted binary cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted BCE loss for one batch.

        Args:
            model: module whose output exposes a ``logits`` attribute holding
                one raw score per sample.
            inputs: batch dict; its 'labels' entry is removed here and the
                remaining entries are passed to ``model`` as keyword
                arguments.
            return_outputs: when true, also return the model's predictions.
            num_items_in_batch: accepted because recent transformers releases
                pass it as a keyword argument to ``compute_loss``; it is not
                used (the loss is the mean over the samples of the batch).

        Returns:
            The scalar loss, or ``(loss, {'logits': probabilities})`` when
            ``return_outputs`` is true. Despite the key name the values are
            sigmoid outputs: ``compute_metrics`` rounds them to get a class.
        """
        labels = inputs.pop("labels")
        targets = labels.float().view(-1)

        # One weight per sample, looked up by integer class label and created
        # on the labels' device, so this works wherever the batch lives.
        weights = torch.tensor(
            [class_weights[int(target)] for target in targets.tolist()],
            dtype=targets.dtype,
            device=targets.device,
        )

        outputs = model(**inputs)
        logits = outputs.logits.view(-1)

        # Same loss as sigmoid followed by BCELoss(weight=weights), computed
        # from the raw logits for numerical stability.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, targets, weight=weights
        )

        if not return_outputs:
            return loss

        return loss, {'logits': torch.sigmoid(logits)}

In [7]:
training_args = TrainingArguments(
    "vision_transformer_checkpoints",  # output directory for checkpoints
    overwrite_output_dir=True,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — adjust if the installed version rejects it.
    evaluation_strategy="epoch",
    # `load_best_model_at_end` can only restore a checkpoint that was saved at
    # an evaluation point, so checkpoints must follow the evaluation schedule;
    # recent transformers releases raise a ValueError when the two differ.
    save_strategy="epoch",
    logging_strategy='epoch',
    dataloader_pin_memory=True,
    per_device_train_batch_size=64,
    num_train_epochs=10.0,
    per_device_eval_batch_size=32,
    load_best_model_at_end=True,
    metric_for_best_model='eval_accuracy',
    learning_rate=5e-4,
    lr_scheduler_type='cosine'
)

# NOTE(review): `test_ds` is used both to select the best checkpoint and to
# report accuracy, so the reported number is presumably optimistic — consider
# a separate validation split.
trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    data_collator=data_collator,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

# The learning-rate schedule is selected above via `lr_scheduler_type='cosine'`.

In [9]:
# Run the training loop with the arguments configured on `trainer`. As the last
# expression of the cell, the returned TrainOutput is displayed below the
# per-epoch loss/accuracy table.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.4941,0.518921,0.8125
2,0.3166,0.50642,0.834936
3,0.2645,0.495775,0.842949
4,0.2392,0.520427,0.831731
5,0.2241,0.565012,0.81891
6,0.2211,0.522782,0.826923
7,0.2123,0.503131,0.841346
8,0.2099,0.51869,0.828526
9,0.208,0.517521,0.828526
10,0.2037,0.517021,0.828526


TrainOutput(global_step=820, training_loss=0.25934281232880385, metrics={'train_runtime': 4040.6114, 'train_samples_per_second': 12.909, 'train_steps_per_second': 0.203, 'total_flos': 0.0, 'train_loss': 0.25934281232880385, 'epoch': 10.0})