In [3]:
import requests
import os
import logging
import gdown
import random

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from torchvision.datasets import VOCSegmentation
import torchmetrics
import torchvision
import albumentations as A

import re
import string
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import cv2
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as T
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from scipy.io import loadmat
from sklearn.manifold import TSNE
from torchmetrics.classification import MulticlassF1Score, JaccardIndex, MulticlassPrecision, MulticlassRecall, MulticlassAveragePrecision
import pandas as pd
from torchinfo import torchinfo

from transformers import ConvNextV2Model, BertModel, BertTokenizer, ViTModel, ViTConfig
from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, CLIPModel, CLIPTokenizer, CLIPProcessor
from transformers import DeiTConfig, DeiTFeatureExtractor, DeiTImageProcessor, DeiTModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import timm
from transformers import AlignModel, AlignProcessor, AlignConfig, AlignVisionConfig, AlignTextConfig
from transformers import CLIPModel, CLIPProcessor, CLIPConfig, CLIPTextConfig, CLIPVisionConfig, CLIPImageProcessor

from torchmetrics.functional import pairwise_cosine_similarity
from torchmetrics.classification import MultilabelAccuracy

In [4]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
logging.basicConfig(level=logging.ERROR)

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda:0")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")




There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 3090


In [5]:
ROOT_DIR = '../Datasets/ocular-disease-recognition-odir5k/'

## ALIGN

In [6]:
BATCH_SIZE = 128

In [7]:
CSV_PATH = ROOT_DIR + 'dataset_single_eye.csv'
TEST_CSV = ROOT_DIR + 'TESTING_dataset_single_eye.csv'
IMG_PATH = ROOT_DIR + 'preprocessed_images/'

In [8]:
torchvision.io.read_image(IMG_PATH + '0_left.jpg').shape

torch.Size([3, 512, 512])

In [9]:
train_val_df = pd.read_csv(CSV_PATH)
test_df = pd.read_csv(TEST_CSV)

In [10]:
# keyword_map = {'N' : 'normal',
#                'D' : 'diabetes',
#                'G' : 'glaucoma',
#                'C' : 'cataract',
#                'A' : 'age-related-macular-degeneration',
#                'H' : "hypertension",
#                'M' : 'myopia',
#                'O' : 'other defects'}

# def reconstruct_kw(row, kw_map : dict = keyword_map):
#     diagnosis = ""
#     for kw in kw_map.keys():
#         if row[kw] == 1:
#             diagnosis+= kw_map[kw] + " "
#     diagnosis = diagnosis.strip()
#     return diagnosis

# train_val_df['Keywords'] = train_val_df.apply(lambda x : reconstruct_kw(x), axis = 1)
# test_df['Keywords'] = test_df.apply(lambda x : reconstruct_kw(x), axis = 1)


In [11]:
def preprocess_text(df:pd.DataFrame):
    df['Keywords'] = df['Keywords'].str.lower()
    df['Keywords'] = df['Keywords'].apply(lambda x: " ".join(x.split()))
    df['Keywords'] = df['Keywords'].apply(lambda x: re.sub('[%s]' % re.escape(string.punctuation), '' , x))
    return df
train_val_df = preprocess_text(train_val_df)
test_df = preprocess_text(test_df)

In [12]:
np.max(train_val_df['Keywords'].apply(lambda x: len(x.split())))

10

In [13]:
train_df, val_df = train_test_split(train_val_df, test_size = 0.15, random_state= 123456)
len(train_df), len(val_df)

(4877, 861)

In [14]:
IMG_SIZE = (224, 224)

rescale_transform = torchvision.transforms.Compose([
    #torchvision.transforms.CenterCrop(IMG_SIZE),
    torchvision.transforms.Resize(IMG_SIZE, antialias = False, interpolation = torchvision.transforms.InterpolationMode.NEAREST),
    torchvision.transforms.Normalize(
        timm.data.constants.IMAGENET_DEFAULT_MEAN,
        timm.data.constants.IMAGENET_DEFAULT_STD
    )
])

augmentation = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p = 0.5),
    torchvision.transforms.RandomVerticalFlip(p= 0.5),
    #torchvision.transforms.RandomRotation(90)
])

In [15]:
processor = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
extractor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')

In [16]:
processor(text = 'diabetic retinopathy', padding = True, max_length = 100, truncation = True, return_tensors = 'pt')

{'input_ids': tensor([[49406, 30230, 31270, 28466, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [17]:
class ODIRDatasetMM(Dataset) :
    def __init__(self, df, IMG_FOLDER, tokenizer = processor, feature_extractor = extractor, augmentation = None) :
        '''
        id : list of samples ids as string
        '''
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.df = df
        #self.text = [tokenizer(text = x, padding = True, max_length = 100, truncation = True, return_tensors = 'pt') for x in df['Keywords']]
        self.eye = df['eye']
        self.labels = torch.tensor(df[['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']].to_numpy()).float()
        self.img_dir = [IMG_PATH + x for x in df['Image']]

        self.augmentation = augmentation

        self.images = [extractor(Image.open(x), return_tensors='pt')['pixel_values'][0] for x in self.img_dir]
        #self.transform = transform
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        texts = self.df.iloc[idx]['Keywords']
        batch_imgs = self.images[idx]
        if(self.augmentation is not None):
            batch_imgs = self.augmentation(batch_imgs)
        return batch_imgs, texts, self.labels[idx]

In [None]:
# train_dataset = ODIRDatasetMM(train_df, IMG_PATH)#, augmentation = augmentation)
# val_dataset   = ODIRDatasetMM(val_df, IMG_PATH)
# test_dataset  = ODIRDatasetMM(test_df, IMG_PATH)

In [None]:
# with open("datasets.trch", 'wb') as f:
#     torch.save([train_dataset, val_dataset, test_dataset], f)

In [19]:
with open("datasets.trch", 'rb') as f:
    train_dataset, val_dataset, test_dataset = torch.load(f)

In [20]:
train_dataloader = DataLoader(train_dataset, batch_size= BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size= BATCH_SIZE)

In [21]:
# contrastive learning on training data finetuning

In [22]:

#odict_keys(['loss', 'logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])

In [23]:
class ClassificationNet(nn.Module):
    def __init__(self, dims = 640, drop_prob = 0.5):
        super().__init__()
        self.layer_1 = nn.Linear(dims, 768)
        self.layer_2 = nn.Linear(768, 768)
        self.layer_3 = nn.Linear(768, dims)
        
        self.dropout = nn.Dropout(drop_prob)
        
    def forward(self, input):
        resi = input
        out = self.dropout(input)
        out = F.relu(self.layer_1(out))
        out = self.dropout(out)
        out = F.relu(self.layer_2(out))
        out = self.dropout(out)
        out = F.relu(self.layer_3(out))
        out = out + resi
        return out


In [24]:
class ContrastiveLearning(nn.Module):
    def __init__(self):
        super().__init__()
        self.align = CLIPModel(CLIPConfig())
        self.text_net = ClassificationNet(dims = 512, drop_prob=0.5)
        self.text_head = nn.Linear(512, 8)
        self.img_net = ClassificationNet(dims = 512, drop_prob=0.5)
        self.img_head = nn.Linear(512, 8)
    
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, pixel_values, input_ids, attention_mask):
        out = self.align(pixel_values = pixel_values, input_ids = input_ids, attention_mask = attention_mask, return_loss = True)

        img_outs = out['image_embeds']
        # img_outs = self.dropout(img_outs)
        img_outs = F.sigmoid(self.img_head(img_outs))

        txt_outs = out['text_embeds']
        # txt_outs = self.dropout(txt_outs)
        txt_outs = F.sigmoid(self.text_head(txt_outs))

        return out, img_outs, txt_outs

In [25]:
model= ContrastiveLearning().to(device)

In [26]:
torch.cuda.empty_cache()
weights = torch.tensor([1,1.2,1.5,1.5,1.5,1.5, 1.5, 1.2]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-6)
train_img_acc = MultilabelAccuracy(8, average = 'micro').to(device)
train_txt_acc = MultilabelAccuracy(8, average = 'micro').to(device)
val_img_acc = MultilabelAccuracy(8, average = 'micro').to(device)
val_txt_acc = MultilabelAccuracy(8, average = 'micro').to(device)

img_loss_fn = nn.BCELoss(weights)
txt_loss_fn = nn.BCELoss(weights)

EPOCHS = 100

TOTAL_COUNT = 5
IMG_TRAIN_COUNTER = TOTAL_COUNT
IMG_TRAIN_COUNTER_INIT = TOTAL_COUNT
MINI_EPOCHS = TOTAL_COUNT

for epoch_num in range(EPOCHS):

  total_acc_train = 0
  total_loss_train = 0

  for train_image, train_text, train_label in tqdm(train_dataloader):
        optimizer.zero_grad()
        train_label = train_label.to(device)
        train_image = train_image.to(device)
        train_text = processor(train_text, padding = True, max_length = 100, truncation = True, return_tensors = 'pt')

        # print(train_text['attention_mask'].shape)
        # break
        mask = train_text['attention_mask'].to(device)
        input_id = train_text['input_ids'].to(device)
        output, img_outs, txt_outs = model.forward(pixel_values = train_image, input_ids = input_id, attention_mask = mask)

        batch_loss = 5*output['loss'] + 5*txt_loss_fn(txt_outs, train_label) + .25*img_loss_fn(img_outs, train_label) 
        batch_loss.backward()
        optimizer.step()
        total_loss_train += batch_loss.item()
      
        train_img_acc(img_outs, train_label)
        train_txt_acc(txt_outs, train_label)

  # IMG_TRAIN_COUNTER-=1

  # # TRAINING LOOP FOR OTHER LAYERS
  # if(IMG_TRAIN_COUNTER == 0):
  #   model.align.requires_grad_(False)
  #   IMG_TRAIN_COUNTER = 10
  #   mini_optim = torch.optim.Adam(model.parameters(), lr = 0.00005)
  #   for i in range(MINI_EPOCHS):
  #     print("-", end = "")
  #     for train_image, train_text, train_label in train_dataloader:
  #       mini_optim.zero_grad()
  #       train_label = train_label.to(device)
  #       train_image = train_image.to(device)
  #       train_text = processor(train_text, padding = True, max_length = 100, truncation = True, return_tensors = 'pt')

  #       # print(train_text['attention_mask'].shape)
  #       # break
  #       mask = train_text['attention_mask'].to(device)
  #       input_id = train_text['input_ids'].to(device)
  #       output, img_outs, txt_outs = model.forward(pixel_values = train_image, input_ids = input_id, attention_mask = mask)

  #       batch_loss = img_loss_fn(img_outs, train_label) #txt_loss_fn(txt_outs, train_label)+ img_loss_fn(img_outs, train_label) 
  #       batch_loss.backward()
  #       mini_optim.step()
  #   model.align.requires_grad_(True)
  #   print()
    


  total_loss_val = 0

  with torch.no_grad():
      #Validation
      for val_image, val_text, val_label in val_dataloader:

          val_label = val_label.to(device)
          val_image = val_image.to(device)
          val_text = processor(val_text, padding = True, max_length = 100, truncation = True, return_tensors = 'pt')
          mask = val_text['attention_mask'].to(device)
          input_id = val_text['input_ids'].to(device)

          output, img_outs, txt_outs = model(pixel_values = val_image, input_ids = input_id, attention_mask = mask)
          batch_loss = 5*output['loss'] + 5*txt_loss_fn(txt_outs, train_label) + .25*img_loss_fn(img_outs, train_label)
          total_loss_val += batch_loss.item()
          
          val_img_acc(img_outs, val_label)
          val_txt_acc(txt_outs, val_label)
              
      
  avg_train_loss = total_loss_train/len(train_df)

  avg_val_loss = total_loss_val/len(val_df)


  print("Epoch [{}/{}], Train Loss: {:.4f}, acc img: {:.4f}, txt : {:.4f}".format(epoch_num+1, EPOCHS, avg_train_loss*BATCH_SIZE, train_img_acc.compute(), train_txt_acc.compute()))
  print("Epoch [{}/{}], Val Loss: {:.4f}, acc img: {:.4f}, txt : {:.4f}".format(epoch_num+1, EPOCHS, avg_val_loss*BATCH_SIZE, val_img_acc.compute(), val_txt_acc.compute()))

  train_img_acc.reset()
  train_txt_acc.reset()
  val_img_acc.reset()
  val_txt_acc.reset()

  torch.save(model.state_dict(), './' + 'checkpoint' + '.pt' )

torch.save(model.state_dict(), './' + 'finetuned' + '.pt' )


100%|██████████| 39/39 [00:21<00:00,  1.85it/s]


Epoch [1/100], Train Loss: 29.4284, acc img: 0.2944, txt : 0.7384
Epoch [1/100], Val Loss: 6.8995, acc img: 0.3197, txt : 0.9056


100%|██████████| 39/39 [00:18<00:00,  2.15it/s]


Epoch [2/100], Train Loss: 29.1346, acc img: 0.3311, txt : 0.9202
Epoch [2/100], Val Loss: 6.8577, acc img: 0.3503, txt : 0.9210


100%|██████████| 39/39 [00:18<00:00,  2.15it/s]


Epoch [3/100], Train Loss: 28.9997, acc img: 0.3672, txt : 0.9276
Epoch [3/100], Val Loss: 6.8340, acc img: 0.3965, txt : 0.9287


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [4/100], Train Loss: 28.9223, acc img: 0.5095, txt : 0.9305
Epoch [4/100], Val Loss: 6.8176, acc img: 0.6082, txt : 0.9309


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [5/100], Train Loss: 28.8674, acc img: 0.6292, txt : 0.9324
Epoch [5/100], Val Loss: 6.8069, acc img: 0.6600, txt : 0.9313


100%|██████████| 39/39 [00:17<00:00,  2.18it/s]


Epoch [6/100], Train Loss: 28.8427, acc img: 0.6821, txt : 0.9371
Epoch [6/100], Val Loss: 6.7944, acc img: 0.6974, txt : 0.9488


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [7/100], Train Loss: 28.7921, acc img: 0.7490, txt : 0.9521
Epoch [7/100], Val Loss: 6.7912, acc img: 0.7872, txt : 0.9501


100%|██████████| 39/39 [00:18<00:00,  2.16it/s]


Epoch [8/100], Train Loss: 28.7699, acc img: 0.8205, txt : 0.9527
Epoch [8/100], Val Loss: 6.7906, acc img: 0.8393, txt : 0.9488


100%|██████████| 39/39 [00:18<00:00,  2.16it/s]


Epoch [9/100], Train Loss: 28.7435, acc img: 0.8529, txt : 0.9522
Epoch [9/100], Val Loss: 6.7685, acc img: 0.8525, txt : 0.9488


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [10/100], Train Loss: 28.7201, acc img: 0.8585, txt : 0.9526
Epoch [10/100], Val Loss: 6.7712, acc img: 0.8551, txt : 0.9508


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [11/100], Train Loss: 28.6994, acc img: 0.8610, txt : 0.9530
Epoch [11/100], Val Loss: 6.7676, acc img: 0.8574, txt : 0.9493


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [12/100], Train Loss: 28.6865, acc img: 0.8625, txt : 0.9610
Epoch [12/100], Val Loss: 6.7491, acc img: 0.8602, txt : 0.9579


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [13/100], Train Loss: 28.6483, acc img: 0.8634, txt : 0.9626
Epoch [13/100], Val Loss: 6.7425, acc img: 0.8606, txt : 0.9598


100%|██████████| 39/39 [00:17<00:00,  2.18it/s]


Epoch [14/100], Train Loss: 28.6467, acc img: 0.8630, txt : 0.9628
Epoch [14/100], Val Loss: 6.7380, acc img: 0.8602, txt : 0.9599


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [15/100], Train Loss: 28.5702, acc img: 0.8626, txt : 0.9632
Epoch [15/100], Val Loss: 6.7489, acc img: 0.8584, txt : 0.9583


100%|██████████| 39/39 [00:17<00:00,  2.17it/s]


Epoch [16/100], Train Loss: 28.5546, acc img: 0.8612, txt : 0.9629
Epoch [16/100], Val Loss: 6.7311, acc img: 0.8570, txt : 0.9596


100%|██████████| 39/39 [00:18<00:00,  2.17it/s]


Epoch [17/100], Train Loss: 28.5082, acc img: 0.8586, txt : 0.9631
Epoch [17/100], Val Loss: 6.7255, acc img: 0.8553, txt : 0.9598


100%|██████████| 39/39 [00:18<00:00,  2.16it/s]


Epoch [18/100], Train Loss: 28.4902, acc img: 0.8568, txt : 0.9631
Epoch [18/100], Val Loss: 6.7154, acc img: 0.8534, txt : 0.9596


 21%|██        | 8/39 [00:04<00:16,  1.84it/s]


RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
with open("finetuned.pt", 'rb') as f:
    model.load_state_dict(torch.load(f))

## Evaluation on test set

In [None]:
#criterion = nn.BCELoss(torch.tensor([0.5, 1, 5, 5, 5, 6, 5, 1]).float().to(device))

test_loss = 0
test_acc  = 0

AVERAGING = 'weighted'
PREC = torchmetrics.classification.MultilabelPrecision(8, average = AVERAGING).to(device)#, validate_args = False)
ACC = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)#, validate_args = False)
REC = torchmetrics.classification.MultilabelRecall(8, average = AVERAGING).to(device)#, validate_args = False)
F1_SCORE = torchmetrics.classification.MultilabelF1Score(8, average = AVERAGING).to(device)#, validate_args = False)
F_BETA_SCORE = torchmetrics.classification.MultilabelFBetaScore(beta = 0.8, num_classes = 8, num_labels = 8, average = AVERAGING).to(device)#, validate_args = False)
KAPPA = torchmetrics.classification.MulticlassCohenKappa(8).to(device)#, validate_args = False)
AUC = torchmetrics.classification.MultilabelAUROC(8, average = AVERAGING).to(device)#, validate_args = False)

for train_image, train_text, train_label in tqdm(test_dataloader): 
    with torch.no_grad():
        train_label = train_label.to(device)
        train_image = train_image.to(device)
        mask = train_text['attention_mask'].to(device)
        input_id = train_text['input_ids'].squeeze(1).to(device)
        
        # logits_per_image, logits_per_text
        #out_img, out_txt = model.forward(train_image, input_id, mask, contrastive = True)
        output, predictions, txt_outs = model.forward(pixel_values = train_image, input_ids = input_id, attention_mask = mask)



        train_label = train_label.long()
        PREC(predictions, train_label)
        ACC(predictions, train_label)
        REC(predictions, train_label)
        F1_SCORE(predictions, train_label)
        F_BETA_SCORE(predictions, train_label)
        KAPPA(predictions, train_label)
        AUC(predictions, train_label)


add_prec = PREC.compute()
add_acc = ACC.compute()
add_rec = REC.compute()
add_f1 = F1_SCORE.compute()
add_fbeta = F_BETA_SCORE.compute()
add_kappa = KAPPA.compute()
add_auc = AUC.compute()

avg_test_loss = test_loss/len(test_df)*BATCH_SIZE
avg_test_acc  = test_acc /len(test_df)

print("Acc: {:3f}\nPrec: {:3f}\nRecall: {:.3f}\nF1-score: {:.3f}\nF-Beta-score: {:.3f}\nKappa: {:.3f}\nAUC: {:.3f}".format(add_acc, add_prec,add_rec, add_f1, add_fbeta, add_kappa, add_auc))
torch.cuda.empty_cache()

  0%|          | 0/5 [00:00<?, ?it/s]


TypeError: tuple indices must be integers or slices, not str