In [1]:
import requests
import os
import logging
import gdown
import random

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from torchvision.datasets import VOCSegmentation
import torchmetrics
import torchvision
import albumentations as A

import re
import string
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import cv2
from PIL import Image
from tqdm import tqdm
import torchvision.transforms as T
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from scipy.io import loadmat
from sklearn.manifold import TSNE
from torchmetrics.classification import MulticlassF1Score, JaccardIndex, MulticlassPrecision, MulticlassRecall, MulticlassAveragePrecision
import pandas as pd
from torchinfo import torchinfo

from transformers import ConvNextV2Model, BertModel, BertTokenizer, ViTModel, ViTConfig
from transformers import AutoTokenizer, AutoModel, RobertaTokenizer, CLIPModel, CLIPTokenizer, CLIPProcessor
from transformers import DeiTConfig, DeiTFeatureExtractor, DeiTImageProcessor, DeiTModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import timm
from torchmetrics.functional import pairwise_cosine_similarity

In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
logging.basicConfig(level=logging.ERROR)

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda:0")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")


There are 1 GPU(s) available.
We will use the GPU: NVIDIA GeForce RTX 3090


In [3]:
ROOT_DIR = '../Datasets/ocular-disease-recognition-odir5k/'

## ALIGN

In [29]:
BATCH_SIZE = 32

In [30]:
CSV_PATH = ROOT_DIR + 'dataset_single_eye.csv'
TEST_CSV = ROOT_DIR + 'TESTING_dataset_single_eye.csv'
IMG_PATH = ROOT_DIR + 'preprocessed_images/'

In [31]:
torchvision.io.read_image(IMG_PATH + '0_left.jpg').shape

torch.Size([3, 512, 512])

In [32]:
train_val_df = pd.read_csv(CSV_PATH)
test_df = pd.read_csv(TEST_CSV)

In [33]:
def preprocess_text(df:pd.DataFrame):
    df['Keywords'] = df['Keywords'].str.lower()
    df['Keywords'] = df['Keywords'].apply(lambda x: " ".join(x.split()))
    df['Keywords'] = df['Keywords'].apply(lambda x: re.sub('[%s]' % re.escape(string.punctuation), '' , x))
    return df
train_val_df = preprocess_text(train_val_df)
test_df = preprocess_text(test_df)

In [34]:
np.max(train_val_df['Keywords'].apply(lambda x: len(x.split())))

10

In [35]:
train_df, val_df = train_test_split(train_val_df, test_size = 0.15, random_state= 123456)
len(train_df), len(val_df)

(4877, 861)

In [36]:
IMG_SIZE = (224, 224)

rescale_transform = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(IMG_SIZE),
    #torchvision.transforms.Resize(IMG_SIZE, antialias = False, interpolation = torchvision.transforms.InterpolationMode.NEAREST),
    # torchvision.transforms.Normalize(
    #     timm.data.constants.IMAGENET_DEFAULT_MEAN,
    #     timm.data.constants.IMAGENET_DEFAULT_STD
    # )
])

augmentation = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p = 0.5),
    torchvision.transforms.RandomVerticalFlip(p= 0.5),
    #torchvision.transforms.RandomRotation(90)
])

In [51]:
processor = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')

Using cache found in C:\Users\krish/.cache\torch\hub\huggingface_pytorch-transformers_main


In [52]:
class ODIRDatasetMM(Dataset) :
    def __init__(self, df, IMG_FOLDER, tokenizer = processor, feature_extractor = rescale_transform, augmentation = None) :
        '''
        id : list of samples ids as string
        '''
        self.text = [tokenizer(text = x, padding = 'max_length', max_length = 45, truncation = True, return_tensors = 'pt') for x in df['Keywords']]
        self.eye = df['eye']
        self.labels = torch.tensor(df[['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']].to_numpy()).float()
        self.img_dir = [IMG_PATH + x for x in df['Image']]

        self.augmentation = augmentation

        self.images = [feature_extractor(torchvision.io.read_image(x).float()/255.0) for x in self.img_dir]
        #self.transform = transform
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        batch_imgs = self.images[idx]
        if(self.augmentation is not None):
            batch_imgs = self.augmentation(batch_imgs)
        return batch_imgs, self.text[idx], self.labels[idx]

In [53]:
train_dataset = ODIRDatasetMM(train_df, IMG_PATH, augmentation = augmentation)
val_dataset   = ODIRDatasetMM(val_df, IMG_PATH)
test_dataset  = ODIRDatasetMM(test_df, IMG_PATH)

In [54]:
train_dataloader = DataLoader(train_dataset, batch_size= BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size= BATCH_SIZE)

In [41]:
# contrastive learning on training data finetuning

In [43]:
# prepare two models: BERT vs ConvNext, try to compute contrastive losses
class ContrastiveLearning(nn.Module):
    def __init__(self, drop_prob = 0.4):
        super().__init__()
        self.img_model = DeiTModel.from_pretrained("facebook/deit-base-patch16-224")                                      
        # self.img_model.classifier = nn.Sequential(
        #     nn.Linear(1536, 768)
        # )

        self.txt_model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'roberta-base')         #output 768 features
        
        # image model classification head
        self.fc1 = nn.Linear(768, 768)
        self.fc2 = nn.Linear(768, 768)
        self.img_head = nn.Linear(768, 8)
        
        # text model classification ehad
        self.fc3 = nn.Linear(768, 768)
        self.fc4 = nn.Linear(768, 768)
        self.text_head = nn.Linear(768, 8)

        self.dropout_layer = nn.Dropout(drop_prob)

    def forward(self, img_input = None, input_ids = None, attn_mask = None, contrastive = False):#, text_class = False):
        if(contrastive):
            # pretraining
            out_txt = self.txt_model(input_ids, attn_mask)
            out_img = self.img_model(img_input)

            out_txt_ret = F.normalize(out_txt['last_hidden_state'], p = 2.0, dim = 1)
            out_img_ret = F.normalize(out_img['last_hidden_state'], p = 2.0, dim = 1)

            #img path
            resi_img = out_img['pooler_output']
            img_route_out = self.dropout_layer(out_img['pooler_output'])
            img_route_out = F.relu(self.fc1(img_route_out))
            img_route_out = self.dropout_layer(img_route_out)
            img_route_out = F.relu(self.fc2(img_route_out))
            img_route_out = img_route_out + resi_img

            img_route_out = F.sigmoid(self.img_head(img_route_out))

            #text path
            resi_txt = out_txt['pooler_output']
            txt_route_out = self.dropout_layer(out_txt['pooler_output'])
            txt_route_out = F.relu(self.fc3(txt_route_out))
            txt_route_out = self.dropout_layer(txt_route_out)
            txt_route_out = F.relu(self.fc4(txt_route_out))
            txt_route_out = txt_route_out + resi_txt

            txt_route_out = F.sigmoid(self.text_head(txt_route_out))
            return out_img_ret, out_txt_ret, img_route_out, txt_route_out

            #return out_img, out_txt
        else:
            #if(not text_class):
            out = self.img_model(img_input)['pooler_output']
            resi_img = out
            img_route_out = self.dropout_layer(out)
            img_route_out = F.relu(self.fc1(img_route_out))
            img_route_out = self.dropout_layer(img_route_out)
            img_route_out = F.relu(self.fc2(img_route_out))
            img_route_out = img_route_out + resi_img

            out = F.sigmoid(self.img_head(out))
            return out
            # else:
            #     out = self.txt_model(input_ids, attn_mask)['pooler_output']
            #     resi = out
            #     out = self.dropout_layer(out)
            #     out = F.relu(self.fc2(out))
            #     out = resi + out

            #     out = F.sigmoid(self.text_head(out))
            #     return out

In [44]:
def contrastive_loss(ten1, ten2, temperature = nn.Parameter(torch.tensor(.25).to(device))):    #...
    #steps = hadamard product
    # trivial for loop 
    sim = torch.einsum('i d, j d -> i j', ten1, ten2) * temperature.exp()
    labels = torch.arange(ten1.size(0), device = device)
    loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
    return loss.mean()

In [45]:
import torch
import torch.nn as nn


class SupConLoss(nn.Module):
    '''Adapted from HobbitLong (www.github.com/HobbitLong/SupContrast)'''
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [46]:
model = ContrastiveLearning().to(device)

You are using a model of type vit to instantiate a model of type deit. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at facebook/deit-base-patch16-224 were not used when initializing DeiTModel: ['vit.encoder.layer.5.attention.attention.query.weight', 'vit.encoder.layer.5.attention.attention.key.bias', 'vit.encoder.layer.2.output.dense.bias', 'vit.encoder.layer.1.attention.attention.key.bias', 'vit.encoder.layer.1.attention.attention.value.weight', 'vit.encoder.layer.7.layernorm_before.weight', 'vit.encoder.layer.3.attention.output.dense.bias', 'vit.encoder.layer.3.layernorm_before.bias', 'vit.encoder.layer.10.attention.output.dense.weight', 'vit.encoder.layer.2.layernorm_after.weight', 'vit.encoder.layer.11.attention.attention.key.bias', 'vit.encoder.layer.8.output.dense.weight', 'vit.encoder.layer.9.layernorm_before.bias', 'vit.encoder.layer.0.attention.output.dense.weight', 'vit.encoder.layer.10.layernorm_before.bias

In [58]:
from transformers import AlignModel, AlignProcessor, AlignConfig

m = AlignModel(AlignConfig())

In [59]:
m

AlignModel(
  (text_model): AlignTextModel(
    (embeddings): AlignTextEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): AlignTextEncoder(
      (layer): ModuleList(
        (0-11): 12 x AlignTextLayer(
          (attention): AlignTextAttention(
            (self): AlignTextSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): AlignTextSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerN

In [None]:
p

In [48]:
torch.cuda.empty_cache()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
weights = torch.tensor([.7,1.2,1.5,1.5,1.5,1.7, 1.5, 1.2]).to(device)
criterion_text = nn.BCELoss(weights)
criterion_image = nn.BCELoss(weights)
cont_loss = contrastive_loss
AVERAGING = 'micro'
acc_train_img = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)#, validate_args = False)
acc_train_text = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)#, validate_args = False)
acc_val_img   = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)
acc_val_text   = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)
myloss_img = SupConLoss()

train_text = False
EPOCHS = 15

for epoch_num in range(EPOCHS):

      total_acc_train = 0
      total_loss_train = 0

      for train_image, train_text, train_label in tqdm(train_dataloader):
          optimizer.zero_grad()

          train_label = train_label.to(device)
          train_image = train_image.to(device)
          mask = train_text['attention_mask'].to(device)
          input_id = train_text['input_ids'].squeeze(1).to(device)
          
          # logits_per_image, logits_per_text
          out_img, out_txt, predictions_img, predictions_text = model.forward(train_image, input_id, mask, contrastive = True)
          #predictions_img =  model.forward(train_image, input_id, mask, contrastive = False)
          acc_train_img(predictions_img, train_label)
          acc_train_text(predictions_text, train_label)

          closs_masks = torch.floor(pairwise_cosine_similarity(train_label, zero_diagonal= False))


          closs = cont_loss(torch.mean(out_img, axis = 1), torch.mean(out_txt, axis = 1))
          text_loss = criterion_text(predictions_text, train_label)
          img_loss = criterion_image(predictions_img, train_label)
          # if(train_text):
          batch_loss = closs*2. + text_loss*2. + img_loss*1.
          # else:
          #   batch_loss = myloss_img(out_img) + img_loss
          batch_loss = closs*2. + text_loss*2. + img_loss*.5
          batch_loss.backward()
          optimizer.step()
          total_loss_train += batch_loss.item()
      
      total_acc_val = 0
      total_loss_val = 0

      with torch.no_grad():

          for val_image, val_text, val_label in val_dataloader:

              val_label = val_label.to(device)
              val_image = val_image.to(device)
              mask = val_text['attention_mask'].to(device)
              input_id = val_text['input_ids'].squeeze(1).to(device)

              out_img, out_txt, predictions_img, predictions_text = model.forward(val_image, input_id, mask, contrastive = True)
              #predictions_img =  model.forward(val_image, input_id, mask, contrastive = False)
              acc_val_img(predictions_img, val_label)
              acc_val_text(predictions_text, val_label)

              closs_masks = torch.floor(pairwise_cosine_similarity(val_label, zero_diagonal= False))


              closs = cont_loss(torch.mean(out_img, axis = 1), torch.mean(out_txt, axis = 1))
              text_loss = criterion_text(predictions_text, val_label)
              img_loss = criterion_image(predictions_img, val_label)
              # if(train_text):
              batch_loss = closs*2. + text_loss*2. + img_loss*1.
              # else:
              #   batch_loss = closs
              batch_loss = closs*2. + text_loss*2. + img_loss*.5
              total_loss_val += batch_loss.item()

             # acc_val(predictions, val_label)
              
      
      avg_train_loss = total_loss_train/len(train_df)
    #   train_accuracy = total_acc_train/len(train_df)

      avg_val_loss = total_loss_val/len(val_df)
    #   val_accuracy = total_acc_val/len(dev_df)

      # if(acc_train_text.compute() >= 0.99):
      #   print("Fixing Text model component!")
      #   train_text = False
      #   model.txt_model.requires_grad_(False)
      #   model.text_head.requires_grad_(False)
      # else:
      #   print("Unfreezing text model component")
      #   train_text = True
      #   model.txt_model.requires_grad_(True)
      #   model.text_head.requires_grad_(True)

      print("Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.4f} img, {:.4f} txt".format(epoch_num+1, EPOCHS, avg_train_loss*BATCH_SIZE, acc_train_img.compute(), acc_train_text.compute()))
      print("Epoch [{}/{}], Val Loss: {:.4f}, Val Accuracy: {:.4f} img, {:.4f} txt".format(epoch_num+1, EPOCHS, avg_val_loss*BATCH_SIZE, acc_val_img.compute(), acc_val_text.compute()))

      acc_train_img.reset()
      acc_train_text.reset()
      acc_val_img.reset
      acc_val_text.reset()
      torch.save(model.state_dict(), './' + 'checkpoint' + '.pt' )

torch.save(model.state_dict(), './' + 'finetuned' + '.pt' )


100%|██████████| 153/153 [01:07<00:00,  2.27it/s]


Epoch [1/15], Train Loss: 7.1651, Train Accuracy: 0.8723 img, 0.9977 txt
Epoch [1/15], Val Loss: 7.1367, Val Accuracy: 0.8749 img, 0.9997 txt


100%|██████████| 153/153 [01:07<00:00,  2.28it/s]


Epoch [2/15], Train Loss: 7.1242, Train Accuracy: 0.8735 img, 0.9998 txt
Epoch [2/15], Val Loss: 7.1638, Val Accuracy: 0.8747 img, 0.9984 txt


100%|██████████| 153/153 [01:08<00:00,  2.24it/s]


Epoch [3/15], Train Loss: 7.1119, Train Accuracy: 0.8756 img, 0.9990 txt
Epoch [3/15], Val Loss: 7.1027, Val Accuracy: 0.8748 img, 0.9997 txt


100%|██████████| 153/153 [01:08<00:00,  2.24it/s]


Epoch [4/15], Train Loss: 7.0751, Train Accuracy: 0.8745 img, 0.9992 txt
Epoch [4/15], Val Loss: 7.0631, Val Accuracy: 0.8741 img, 0.9997 txt


  6%|▌         | 9/153 [00:04<01:06,  2.15it/s]


KeyboardInterrupt: 

In [None]:
with open("finetuned.pt", 'rb') as f:
    model.load_state_dict(torch.load(f))

## Evaluation on test set

In [None]:
criterion = nn.BCELoss(torch.tensor([0.5, 1, 5, 5, 5, 6, 5, 1]).float().to(device))

test_loss = 0
test_acc  = 0

AVERAGING = 'weighted'
PREC = torchmetrics.classification.MultilabelPrecision(8, average = AVERAGING).to(device)#, validate_args = False)
ACC = torchmetrics.classification.MultilabelAccuracy(8, average = AVERAGING).to(device)#, validate_args = False)
REC = torchmetrics.classification.MultilabelRecall(8, average = AVERAGING).to(device)#, validate_args = False)
F1_SCORE = torchmetrics.classification.MultilabelF1Score(8, average = AVERAGING).to(device)#, validate_args = False)
F_BETA_SCORE = torchmetrics.classification.MultilabelFBetaScore(beta = 0.8, num_classes = 8, num_labels = 8, average = AVERAGING).to(device)#, validate_args = False)
KAPPA = torchmetrics.classification.MulticlassCohenKappa(8).to(device)#, validate_args = False)
AUC = torchmetrics.classification.MultilabelAUROC(8, average = AVERAGING).to(device)#, validate_args = False)

for train_image, train_text, train_label in tqdm(test_dataloader): 
    with torch.no_grad():
        train_label = train_label.to(device)
        train_image = train_image.to(device)
        mask = train_text['attention_mask'].to(device)
        input_id = train_text['input_ids'].squeeze(1).to(device)
        
        # logits_per_image, logits_per_text
        #out_img, out_txt = model.forward(train_image, input_id, mask, contrastive = True)
        predictions = model.forward(train_image, contrastive = False)



        train_label = train_label.long()
        PREC(predictions, train_label)
        ACC(predictions, train_label)
        REC(predictions, train_label)
        F1_SCORE(predictions, train_label)
        F_BETA_SCORE(predictions, train_label)
        KAPPA(predictions, train_label)
        AUC(predictions, train_label)


add_prec = PREC.compute()
add_acc = ACC.compute()
add_rec = REC.compute()
add_f1 = F1_SCORE.compute()
add_fbeta = F_BETA_SCORE.compute()
add_kappa = KAPPA.compute()
add_auc = AUC.compute()

avg_test_loss = test_loss/len(test_df)*BATCH_SIZE
avg_test_acc  = test_acc /len(test_df)

print("Acc: {:3f}\nPrec: {:3f}\nRecall: {:.3f}\nF1-score: {:.3f}\nF-Beta-score: {:.3f}\nKappa: {:.3f}\nAUC: {:.3f}".format(add_acc, add_prec,add_rec, add_f1, add_fbeta, add_kappa, add_auc))
torch.cuda.empty_cache()

100%|██████████| 13/13 [00:02<00:00,  4.74it/s]


Acc: 0.794288
Prec: 0.626230
Recall: 0.635
F1-score: 0.627
F-Beta-score: 0.626
Kappa: 0.000
AUC: 0.816
