<a href="https://colab.research.google.com/github/nirvanesque/examples/blob/master/Tutorials_FLAIR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone and install FLAIR requirements
!git clone https://github.com/jusiro/FLAIR.git
!pip install -r ./FLAIR/requirements.txt

In [None]:
# Make the cloned repository importable.
# The entry is relative, so it resolves against the current working directory.
import sys
sys.path.extend(['FLAIR'])

In [None]:
# Imports
import numpy as np
import torch

from PIL import Image

from flair import FLAIRModel

In [None]:
# Load model from pre-trained weights.
# Per the recorded output of this cell, the first run downloads several files
# (ImageNet ResNet-50 weights, text-encoder config/vocab/weights, and the FLAIR
# checkpoint flair_resnet.pth), so network access is needed.
model = FLAIRModel(from_checkpoint=True)

Pretrained weights: IMAGENET1K_V1


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]


 Download model to: ./flair/modeling/flair_pretrained_weights/flair_resnet.pth
load model weight from: ./flair/modeling/flair_pretrained_weights/flair_resnet.pth


In [None]:
# Macular hole prediction - Note that this category has not been used for training!

# Candidate text prompts the image is scored against
text = [
    "normal",
    "healthy",
    "macular edema",
    "diabetic retinopathy",
    "glaucoma",
    "macular hole",
    "lesion",
    "lesion in the macula",
]

# Read the sample image shipped with the cloned repository into a numpy array
sample = Image.open("./FLAIR/documents/sample_macular_hole.png")
image = np.array(sample)

# Forward FLAIR model; the result is unpacked as (probabilities, similarity logits)
probs, logits = model(image, text)

# Reference values (recorded output of this cell, rounded to 3 decimals):
#   similarities:  [[-0.323 -2.784  3.164  4.387  5.919  6.64   6.581 10.477]]
#   probabilities: [[0.     0.     0.001  0.002  0.01   0.02   0.019  0.948]]
for title, values in (("Image-Text similarities:", logits), ("Probabilities:", probs)):
    print(title)
    print(values.round(3))

Image-Text similarities:
[[-0.323 -2.784  3.164  4.387  5.919  6.64   6.581 10.477]]
Probabilities:
[[0.    0.    0.001 0.002 0.01  0.02  0.019 0.948]]


In [None]:
# Normal sample prediction

# Candidate text prompts the image is scored against
text = [
    "normal",
    "healthy",
    "macular edema",
    "diabetic retinopathy",
    "glaucoma",
    "macular hole",
    "lesion",
    "lesion in the macula",
]

# Read the sample image shipped with the cloned repository into a numpy array
sample = Image.open("./FLAIR/documents/normal_sample.png")
image = np.array(sample)

# Forward FLAIR model; the result is unpacked as (probabilities, similarity logits)
probs, logits = model(image, text)

# Reference values (rounded to 3 decimals):
#   similarities:  [[7.424  4.969 -1.247 -1.416 -1.022  0.113  0.693 -0.734]]
#   probabilities: [[0.919  0.079  0.     0.     0.     0.001  0.001  0.   ]]
for title, values in (("Image-Text similarities:", logits), ("Probabilities:", probs)):
    print(title)
    print(values.round(3))

Image-Text similarities:
[[ 7.424  4.969 -1.247 -1.416 -1.022  0.113  0.693 -0.734]]
Probabilities:
[[0.919 0.079 0.    0.    0.    0.001 0.001 0.   ]]


In [None]:
# Using expert knowledge prompts
model.eval()

# Load image and set target categories (diabetic retinopathy severity grades).
# Spelling matters: with domain_knowledge=True each category appears to be matched
# against a table of expert descriptions. The previous spelling "no dibaetic
# retinopathy" was echoed back without any expanded prompts in the recorded
# output, unlike every other grade (presumably because it matched no entry).
image = np.array(Image.open("./FLAIR/documents/severe_nonprol_dr.jpg"))
categories = [
    "no diabetic retinopathy",
    "mild diabetic retinopathy",
    "moderate diabetic retinopathy",
    "severe diabetic retinopathy",
    "proliferative diabetic retinopathy",
]

with torch.no_grad():
  # Compute expert knowledge prompts and forward text encoder.
  # Change to domain_knowledge=False to test w/o expert knowledge prompts.
  text_embeds_dict, text_embeds = model.compute_text_embeddings(categories, domain_knowledge=True)

  # Preprocess image and forward vision encoder.
  # A separate name keeps `image` as the raw array, so this cell can be re-run
  # without feeding an already-preprocessed input back into preprocess_image.
  image_tensor = model.preprocess_image(image)
  img_embeds = model.vision_model(image_tensor)

  # Compute similarity matrix and logits
  logits = model.compute_logits(img_embeds, text_embeds)

  # Compute probabilities over the categories
  probs = logits.softmax(dim=-1)

# Move results to host memory as numpy arrays for printing
logits = logits.cpu().numpy()
probs = probs.cpu().numpy()

# Reference values below were recorded with the earlier misspelled first category;
# expect the first class's score (and therefore the probabilities) to shift slightly.
print("Image-Text similarities:")
print(logits.round(3)) # [[-4.153 -0.084  3.818  5.541  4.954]]
print("Probabilities:")
print(probs.round(3))  # [[ 0.     0.002  0.103  0.575  0.32 ]]

['no dibaetic retinopathy']
['only few microaneurysms', 'mild diabetic retinopathy']
['many exudates near the macula', 'many haemorrhages near the macula', 'retinal thickening near the macula', 'hard exudates', 'cotton wool spots', 'few severe haemorrhages', 'moderate diabetic retinopathy']
['venous beading', 'many severe haemorrhages', 'intraretinal microvascular abnormality', 'severe diabetic retinopathy']
['preretinal or vitreous haemorrhage', 'neovascularization', 'proliferative diabetic retinopathy']
Image-Text similarities:
[[-4.153 -0.084  3.818  5.541  4.954]]
Probabilities:
[[0.    0.002 0.103 0.575 0.32 ]]


In [None]:
# Text - to - Text (hierarchical knowledge)
model.eval()

# Full-text categories fed to the text encoder, and short labels used when printing
categories = ["mild diabetic retinopathy", "severe diabetic retinopathy", "proliferative diabetic retinopathy",
              "diabetic macular edema", "few microaneurysms", "many haemorrhages", "neovascularization",
              "exudates in the fovea", "venous beading"]
names = ["mildDR", "sevDR", "prolDR", "DME", "few MA", "many HE", "neoV", "EX fovea", "venous beading"]


with torch.no_grad():
  # Forward text encoder on the categories. With the default arguments the recorded
  # output lists each category on its own, i.e. no expert prompts are added here.
  text_embeds_dict, text_embeds = model.compute_text_embeddings(categories)
  print("%" + "-"*100 + "%")

  # Compute the text-to-text similarity matrix (every category against every category)
  logits = model.compute_logits(text_embeds, text_embeds)

  # Before obtaining softmax probs in relations, mask self-similarity in the diagonal
  # of the similarity matrix. The mask is created on the same device and with the same
  # dtype as the logits; a default CPU tensor would raise a device-mismatch error when
  # the logits are on GPU.
  mask = -100 * torch.eye(logits.shape[0], device=logits.device, dtype=logits.dtype)
  probs_t = (logits + mask).softmax(dim=-1).cpu().numpy().round(3)

  # Row i holds the softmax over the other categories' similarity to category i
  print(names)
  print(probs_t)

['mild diabetic retinopathy']
['severe diabetic retinopathy']
['proliferative diabetic retinopathy']
['diabetic macular edema']
['few microaneurysms']
['many haemorrhages']
['neovascularization']
['exudates in the fovea']
['venous beading']
%----------------------------------------------------------------------------------------------------%
['mildDR', 'sevDR', 'prolDR', 'DME', 'few MA', 'many HE', 'neoV', 'EX fovea', 'venous beading']
[[0.    0.    0.    0.    0.984 0.016 0.    0.    0.   ]
 [0.    0.    0.003 0.    0.    0.008 0.001 0.    0.988]
 [0.004 0.057 0.    0.002 0.003 0.031 0.843 0.    0.061]
 [0.004 0.118 0.299 0.    0.004 0.032 0.279 0.201 0.062]
 [0.984 0.    0.    0.    0.    0.016 0.    0.    0.   ]
 [0.35  0.15  0.035 0.    0.349 0.    0.006 0.002 0.107]
 [0.001 0.013 0.961 0.002 0.001 0.006 0.    0.    0.016]
 [0.011 0.047 0.031 0.267 0.015 0.496 0.096 0.    0.038]
 [0.    0.99  0.004 0.    0.    0.005 0.001 0.    0.   ]]
