In [62]:
import requests
import os
import logging
import gdown
import random

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from torchvision.datasets import VOCSegmentation
import torchmetrics
import torchvision
import albumentations as A

import re
import string
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import cv2
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as T
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from scipy.io import loadmat
from sklearn.manifold import TSNE
from torchmetrics.classification import MulticlassF1Score, JaccardIndex, MulticlassPrecision, MulticlassRecall, MulticlassAveragePrecision
import pandas as pd
from torchinfo import torchinfo

from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, CLIPModel, CLIPTokenizer, CLIPProcessor
from sklearn.model_selection import train_test_split
from tqdm import tqdm


In [63]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
logging.basicConfig(level=logging.ERROR)

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda:0")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")


There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 3090


In [64]:
ROOT_DIR = '../Datasets/ocular-disease-recognition-odir5k/'

## ALIGN

In [65]:
BATCH_SIZE = 16

In [66]:
CSV_PATH = ROOT_DIR + 'dataset_single_eye.csv'
TEST_CSV = ROOT_DIR + 'TESTING_dataset_single_eye.csv'
IMG_PATH = ROOT_DIR + 'preprocessed_images/'

In [67]:
torchvision.io.read_image(IMG_PATH + '0_left.jpg').shape

torch.Size([3, 512, 512])

In [68]:
train_val_df = pd.read_csv(CSV_PATH)
test_df = pd.read_csv(TEST_CSV)

In [69]:
def preprocess_text(df:pd.DataFrame):
    df['Keywords'] = df['Keywords'].str.lower()
    df['Keywords'] = df['Keywords'].apply(lambda x: " ".join(x.split()))
    df['Keywords'] = df['Keywords'].apply(lambda x: re.sub('[%s]' % re.escape(string.punctuation), '' , x))
    return df
train_val_df = preprocess_text(train_val_df)
test_df = preprocess_text(test_df)

In [70]:
np.max(train_val_df['Keywords'].apply(lambda x: len(x.split())))

10

In [71]:
train_df, val_df = train_test_split(train_val_df, test_size = 0.15)
len(train_df), len(val_df)

(4877, 861)

In [72]:
IMG_SIZE = (224, 224)

rescale_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(IMG_SIZE, antialias = True),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
])

In [73]:
processor = RobertaTokenizer.from_pretrained('roberta-base')

In [74]:
class ODIRDatasetMM(Dataset) :
    def __init__(self, df, IMG_FOLDER, tokenizer = processor) :
        '''
        id : list of samples ids as string
        '''
        self.text = [tokenizer(text = x, padding = 'max_length', max_length = 25, truncation = True, return_tensors = 'pt') for x in df['Keywords']]
        self.eye = df['eye']
        self.labels = torch.tensor(df[['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']].to_numpy()).float()
        self.img_dir = [IMG_PATH + x for x in df['Image']]

        self.images = [rescale_transform(torchvision.io.read_image(x).float()/255.0) for x in self.img_dir]
        #self.transform = transform
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        return self.images[idx], self.text[idx], self.labels[idx]

In [75]:
train_dataset = ODIRDatasetMM(train_df, IMG_PATH)
val_dataset   = ODIRDatasetMM(val_df, IMG_PATH)
test_dataset  = ODIRDatasetMM(test_df, IMG_PATH)

In [76]:
train_dataloader = DataLoader(train_dataset, batch_size= BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size= BATCH_SIZE)

In [45]:
# contrastive learning on training data finetuning

In [77]:
from transformers import ConvNextV2Model, RobertaModel

In [78]:
# prepare two models: BERT vs ConvNext, try to compute contrastive losses
class ContrastiveLearning(nn.Module):
    def __init__(self, drop_prob = 0.4):
        super().__init__()
        self.img_model = ConvNextV2Model.from_pretrained("facebook/convnextv2-tiny-1k-224")      #output 768 features
        self.txt_model = RobertaModel.from_pretrained("roberta-base")                            #output 768 features
        
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, 768)
        self.dropout = nn.Dropout(drop_prob)
        self.head = nn.Linear(768, 8)
    def forward(self, img_input, input_ids = None, attn_mask = None, contrastive = False):
        if(contrastive):
            # pretraining
            out_txt = self.txt_model(input_ids, attn_mask)
            out_img = self.img_model(img_input)

            return out_img, out_txt
        else:
            out = self.img_model(img_input, return_dict = False)[1]
            resi = out
            out = F.relu(self.fc1(out))
            out = F.relu(self.fc2(out))
            out = out + resi
            out = self.head(out)
            return F.sigmoid(out)

In [84]:
model = ContrastiveLearning().to(device)

Some weights of the model checkpoint at facebook/convnextv2-tiny-1k-224 were not used when initializing ConvNextV2Model: ['convnextv2.encoder.stages.0.layers.1.grn.weight', 'convnextv2.encoder.stages.2.layers.1.grn.bias', 'convnextv2.encoder.stages.3.layers.1.grn.bias', 'convnextv2.encoder.stages.3.layers.2.grn.bias', 'convnextv2.encoder.stages.2.layers.0.grn.weight', 'convnextv2.encoder.stages.0.layers.2.grn.weight', 'convnextv2.encoder.stages.1.layers.2.grn.weight', 'convnextv2.encoder.stages.0.layers.0.grn.bias', 'convnextv2.encoder.stages.1.layers.1.grn.bias', 'convnextv2.encoder.stages.2.layers.2.grn.bias', 'classifier.weight', 'convnextv2.encoder.stages.1.layers.2.grn.bias', 'convnextv2.encoder.stages.0.layers.1.grn.bias', 'convnextv2.encoder.stages.2.layers.7.grn.weight', 'convnextv2.encoder.stages.2.layers.1.grn.weight', 'convnextv2.encoder.stages.2.layers.0.grn.bias', 'convnextv2.encoder.stages.3.layers.0.grn.weight', 'convnextv2.encoder.stages.1.layers.0.grn.weight', 'convnex

In [85]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

loss_img = nn.CrossEntropyLoss()
loss_text = nn.CrossEntropyLoss()

EPOCHS = 5

for epoch_num in range(EPOCHS):

      total_acc_train = 0
      total_loss_train = 0

      for train_image, train_text, train_label in tqdm(train_dataloader):
          optimizer.zero_grad()

          train_image = train_image.to(device)
          mask = train_text['attention_mask'].to(device)
          input_id = train_text['input_ids'].squeeze(1).to(device)
          
          # logits_per_image, logits_per_text
          out_img, out_txt = model.forward(train_image, input_id, mask, contrastive = True)

          #print(out_txt['pooler_output'].shape)

          #ground_truth = torch.arange(len(train_image),dtype=torch.long,device=device)
          
          batch_loss = (loss_img(out_txt['pooler_output'], out_img['pooler_output']) + loss_text(out_img['pooler_output'], out_txt['pooler_output']))
          batch_loss.backward()
          optimizer.step()
          total_loss_train += batch_loss.item()
          
        #   acc = (output['logits'].argmax(dim=1) == train_label).sum().item()
        #   total_acc_train += acc
      
      total_acc_val = 0
      total_loss_val = 0

      with torch.no_grad():

          for val_image, val_text, val_label in val_dataloader:

              val_label = val_label.to(device)
              val_image = val_image.to(device)
              mask = val_text['attention_mask'].to(device)
              input_id = val_text['input_ids'].squeeze(1).to(device)

              out_img, out_txt = model.forward(val_image, input_id, mask, contrastive = True)
              
              #ground_truth = torch.arange(len(val_image),dtype=torch.long,device=device)
              batch_loss = (loss_img(out_txt['pooler_output'], out_img['pooler_output']) + loss_text(out_img['pooler_output'], out_txt['pooler_output']))
              total_loss_val += batch_loss.item()
              
      
      avg_train_loss = total_loss_train/len(train_df)
    #   train_accuracy = total_acc_train/len(train_df)

      avg_val_loss = total_loss_val/len(val_df)
    #   val_accuracy = total_acc_val/len(dev_df)

      print(f"Epoch [{epoch_num+1}/{EPOCHS}], "f"Train Loss: {avg_train_loss*BATCH_SIZE:.4f}")
      print(f"Epoch [{epoch_num+1}/{EPOCHS}], "f"Val Loss: {avg_val_loss*BATCH_SIZE:.4f}")
      print('-'*60)

      torch.save(model.state_dict(), './' + str(epoch_num+21) + '.pt' )


100%|██████████| 305/305 [01:17<00:00,  3.93it/s]


Epoch [1/5], Train Loss: -7368.4465
Epoch [1/5], Val Loss: -9049.8051
------------------------------------------------------------


 97%|█████████▋| 296/305 [01:15<00:02,  4.00it/s]

In [None]:
with open("25.pt", 'rb') as f:
    model.load_state_dict(torch.load(f))

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

criterion = nn.BCELoss()
test_loss = 0
test_acc  = 0
AVERAGING = 'micro'
acc_train = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)#, validate_args = False)
acc_val   = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)

EPOCHS = 5

for epoch_num in range(EPOCHS):
      total_loss_train = 0

      for train_image, train_text, train_label in tqdm(train_dataloader):
          optimizer.zero_grad()

          train_label = train_label.to(device)
          train_image = train_image.to(device)
          mask = train_text['attention_mask'].to(device)
          input_id = train_text['input_ids'].squeeze(1).to(device)
          
          # logits_per_image, logits_per_text
          out_img = model.forward(train_image, contrastive = False)
          #print(out_img.shape, train_label.shape)
          batch_loss = criterion(out_img, train_label)
          batch_loss.backward()
          optimizer.step()
          total_loss_train += batch_loss.item()
          
          acc_train(out_img, train_label)
        #   acc = (output['logits'].argmax(dim=1) == train_label).sum().item()
        #   total_acc_train += acc
      
      total_acc_val = 0
      total_loss_val = 0

      with torch.no_grad():

          for val_image, val_text, val_label in val_dataloader:

              val_label = val_label.to(device)
              val_image = val_image.to(device)
              mask = val_text['attention_mask'].to(device)
              input_id = val_text['input_ids'].squeeze(1).to(device)

              out_img = model.forward(val_image, contrastive = False)
              

              batch_loss = criterion(out_img, val_label)
              total_loss_val += batch_loss.item()
              acc_val(out_img,val_label)
              
      
      avg_train_loss = total_loss_train/len(train_df)
    #   train_accuracy = total_acc_train/len(train_df)

      avg_val_loss = total_loss_val/len(val_df)
    #   val_accuracy = total_acc_val/len(dev_df)

      print(f"Epoch [{epoch_num+1}/{EPOCHS}], "f"Train Loss: {avg_train_loss*BATCH_SIZE:.4f}, "f"Train Accuracy: {acc_train.compute():.4f}")
      print(f"Epoch [{epoch_num+1}/{EPOCHS}], "f"Val Loss: {avg_val_loss*BATCH_SIZE:.4f}, "f"Val Accuracy: {acc_val.compute():.4f}")
      print('-'*60)
      
      acc_train.reset()
      acc_val.reset()

      torch.save(model.state_dict(), './' + str(epoch_num)+'finetuning' + '.pt' )


100%|██████████| 305/305 [00:51<00:00,  5.91it/s]


Epoch [1/10], Train Loss: 0.0191, Train Accuracy: 0.8711
Epoch [1/10], Val Loss: 0.0189, Val Accuracy: 0.8750
------------------------------------------------------------


100%|██████████| 305/305 [00:51<00:00,  5.96it/s]


Epoch [2/10], Train Loss: 0.0188, Train Accuracy: 0.8733
Epoch [2/10], Val Loss: 0.0188, Val Accuracy: 0.8750
------------------------------------------------------------


100%|██████████| 305/305 [00:51<00:00,  5.89it/s]


Epoch [3/10], Train Loss: 0.0186, Train Accuracy: 0.8737
Epoch [3/10], Val Loss: 0.0186, Val Accuracy: 0.8773
------------------------------------------------------------


100%|██████████| 305/305 [00:51<00:00,  5.90it/s]


Epoch [4/10], Train Loss: 0.0185, Train Accuracy: 0.8752
Epoch [4/10], Val Loss: 0.0186, Val Accuracy: 0.8775
------------------------------------------------------------


100%|██████████| 305/305 [00:51<00:00,  5.93it/s]


Epoch [5/10], Train Loss: 0.0184, Train Accuracy: 0.8749
Epoch [5/10], Val Loss: 0.0185, Val Accuracy: 0.8750
------------------------------------------------------------


 89%|████████▊ | 270/305 [00:45<00:05,  5.93it/s]


KeyboardInterrupt: 

## Evaluation on test set

In [None]:
test_loss = 0
test_acc  = 0
PREC = torchmetrics.classification.MultilabelPrecision(8, average = 'micro').to(device)#, validate_args = False)
ACC = torchmetrics.classification.MultilabelAccuracy(8, average = 'micro').to(device)#, validate_args = False)
REC = torchmetrics.classification.MultilabelRecall(8, average = 'micro').to(device)#, validate_args = False)
F1_SCORE = torchmetrics.classification.MultilabelF1Score(8, average = 'micro').to(device)#, validate_args = False)
F_BETA_SCORE = torchmetrics.classification.MultilabelFBetaScore(beta = 0.8, num_classes = 8, num_labels = 8, average = 'micro').to(device)#, validate_args = False)
KAPPA = torchmetrics.classification.MulticlassCohenKappa(8).to(device)#, validate_args = False)
AUC = torchmetrics.classification.MultilabelAUROC(8, average = 'micro').to(device)#, validate_args = False)

for train_image, train_text, train_label in tqdm(test_dataloader): 
    train_image = train_image.to(device)
    train_label = train_label.to(device)
    with torch.no_grad():
        scores = model(train_image, contrastive = False)
        scores = torch.sigmoid(scores)
    loss = criterion(scores, train_label.float())
    test_loss+= loss.item()
    predicted = torch.round(scores).to(device)
    test_acc+= (torch.sum(predicted == train_label)/(BATCH_SIZE*8))

    train_label = train_label.int()
    PREC(predicted, train_label)
    ACC(predicted, train_label)
    REC(predicted, train_label)
    F1_SCORE(predicted, train_label)
    F_BETA_SCORE(predicted, train_label)
    KAPPA(predicted, train_label)
    AUC(predicted, train_label)


add_prec = PREC.compute()
add_acc = ACC.compute()
add_rec = REC.compute()
add_f1 = F1_SCORE.compute()
add_fbeta = F_BETA_SCORE.compute()
add_kappa = KAPPA.compute()
add_auc = AUC.compute()

avg_test_loss = test_loss/len(test_df)*BATCH_SIZE
avg_test_acc  = test_acc /len(test_df)

print("Acc: {:3f}\nPrec: {:3f}\nRecall: {:.3f}\nF1-score: {:.3f}\nF-Beta-score: {:.3f}\nKappa: {:.3f}\nAUC: {:.3f}".format(add_acc, add_prec,add_rec, add_f1, add_fbeta, add_kappa, add_auc))
torch.cuda.empty_cache()

100%|██████████| 40/40 [00:03<00:00, 12.53it/s]

Acc: 0.125000
Prec: 0.125000
Recall: 1.000
F1-score: 0.222
F-Beta-score: 0.190
Kappa: 0.000
AUC: 0.500



