In [1]:
import pandas as pd
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop
from transformers import BertTokenizer, BertModel, AdamW, BertForSequenceClassification

if torch.cuda.is_available():      
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Quadro RTX 8000


In [2]:
from torchvision import transforms
from transformers import AutoTokenizer, AutoImageProcessor
import re
import string
import torch.nn as nn
from sklearn.metrics import f1_score
import torchvision.models as models
from transformers import BertModel
from skimage import io

In [3]:
class MIMICDataset(Dataset):
    def __init__(self, csv_path):
        # process image and text seperately
        self.df = pd.read_csv(csv_path)
        self.PRED_LABEL = [
            'Atelectasis',
            'Cardiomegaly', 
            'Consolidation',
            'Edema',
            'Enlarged Cardiomediastinum',
            'Fracture',
            'Lung Lesion',
            'Lung Opacity',
            'No Finding',
            'Pleural Effusion',
            'Pleural Other',
            'Pneumonia',
            'Pneumothorax',
            'Support Devices']
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        # Load text
        text_path = self.df.iloc[idx]['text_path']
        text_path = "/scratch/tg2426/MIMIC_CLIP/" + text_path
        with open(text_path, 'r') as f:
            text = f.read()
        # clean special character, \n and extra space
        clean_text = re.sub('[\\(\[#.!?,\'\/\])0-9]', ' ', str(text))
        clean_text = clean_text.replace('\n', ' ').replace('\r', '')
        clean_text = ' '.join(clean_text.split())
        encoding = self.tokenizer(
            clean_text,
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )
        # Load label
        label = torch.FloatTensor(np.zeros(len(self.PRED_LABEL), dtype=float))
        for i in range(0, len(self.PRED_LABEL)):
            if (self.df[self.PRED_LABEL[i].strip()].iloc[idx].astype('float') > 0):
                label[i] = self.df[self.PRED_LABEL[i].strip()].iloc[idx].astype('float')
        
        #Prepare inputs for CLIP model
        
        return (encoding['input_ids'].squeeze(0).clone().detach() , encoding['attention_mask'].squeeze(0).clone().detach(), label)

In [4]:
######################## load csv file, and dataloader######################
BATCH_SIZE = 16
train_loader = DataLoader(MIMICDataset('/'), batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(MIMICDataset(''), batch_size=BATCH_SIZE, shuffle=True) 
test_loader = DataLoader(MIMICDataset(''), batch_size=BATCH_SIZE, shuffle=True) 

In [6]:
def train(model, optimizer, criterion, train_loader, valid_loader, num_epochs=5):
    train_loss_list, train_acc_list, valid_loss_list, valid_acc_list = [], [], [], []
    for epoch in range(num_epochs):
        train_loss, train_acc = 0, 0
        model.train()
        for (input_ids, attention_mask, labels) in train_loader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output = model(input_ids, attention_mask=attention_mask)
            output = output['logits']
#             print(output)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * input_ids.size(0)
            predicted = torch.round(output)
            train_acc += (predicted == labels).sum().item()

        train_loss /= len(train_loader.dataset)
        train_acc /= (len(train_loader.dataset)*14)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        
        #### valid ####
        val_loss, val_acc = 0, 0
        model.eval()
        for (input_ids, attention_mask,labels) in valid_loader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            labels = labels.to(device)
            output = model(input_ids, attention_mask=attention_mask)
            output = output['logits']
            predicted = torch.round(output)
            loss = criterion(output, labels)
            val_loss += loss.item()*input_ids.size(0)
            val_acc += (predicted == labels).sum().item()
      
        val_loss /= len(valid_loader.dataset)
        val_acc /= (len(valid_loader.dataset)*14)
        valid_loss_list.append(val_loss)
        valid_acc_list.append(val_acc)    
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f} \tTrain Accuracy: {:.6f} \tValidation Accuracy: {:.6f} '.format(
            epoch, train_loss, val_loss, train_acc, val_acc))
#         torch.save(model.state_dict(), "./clip_test_result/model/" + 'clip_classify{}.pt'.format(epoch + 1))
    return train_loss_list, train_acc_list, valid_loss_list, valid_acc_list

In [7]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=14).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train_loss_list, train_acc_list, valid_loss_list, valid_acc_list = train(model, optimizer, criterion, train_loader, valid_loader, num_epochs=5)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch: 0 	Training Loss: 4.095170 	Validation Loss: 4.110264 	Train Accuracy: 0.285153 	Validation Accuracy: 0.283143 
Epoch: 1 	Training Loss: 4.134522 	Validation Loss: 3.992084 	Train Accuracy: 0.283867 	Validation Accuracy: 0.247048 
Epoch: 2 	Training Loss: 4.094620 	Validation Loss: 4.005354 	Train Accuracy: 0.284316 	Validation Accuracy: 0.413571 
