In [42]:
import torch
import os
import json
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from PIL import Image
from transformers import AlignProcessor, AlignModel, AutoTokenizer, AutoProcessor
from util import *

In [43]:
mode = 'trained'

In [44]:
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
model = AlignModel.from_pretrained("kakaobrain/align-base")
model.to(device)

if mode == 'trained':
    load_model(model, "../models/align-model-trained-softmax.pth")
model.eval()

Model loaded from ../models/align-model-trained-softmax.pth


AlignModel(
  (text_model): AlignTextModel(
    (embeddings): AlignTextEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): AlignTextEncoder(
      (layer): ModuleList(
        (0-11): 12 x AlignTextLayer(
          (attention): AlignTextAttention(
            (self): AlignTextSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): AlignTextSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerN

In [45]:
def align_text_image(description, image_path):
    image = Image.open(image_path)
    inputs = processor(text=description, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs

In [46]:
def calculate_cosine_similarity(embedding1, embedding2):
    similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)(embedding1, embedding2)
    similarity = similarity.cpu().numpy()
    return similarity

In [47]:
columns = ["Disease", "Average Similarity"]
df = pd.DataFrame(columns=columns)

In [48]:
image_directory = get_train_image_directory()
disease_descriptions = get_disease_skin_symptoms()

subdirectories = os.listdir(image_directory)
keys = list(disease_descriptions.keys())


for subdirectory in subdirectories:
    print(f"Processing {subdirectory}")

    if subdirectory not in keys:
        continue

    description = disease_descriptions[subdirectory]

    total_similarity = 0
    count = 0

    for image_name in os.listdir(os.path.join(image_directory, subdirectory)):

        image_path = os.path.join(image_directory, subdirectory, image_name)
        outputs = align_text_image(description, image_path)

        text_embeddings = outputs.text_embeds
        image_embeddings = outputs.image_embeds

        similarity = calculate_cosine_similarity(text_embeddings, image_embeddings)

        total_similarity += similarity
        count += 1
        
    average_similarity = total_similarity / count
    print(f"Average similarity for {subdirectory}: {average_similarity.item()}")

    new_row = {"Disease": subdirectory, "Average Similarity": average_similarity.item()}
    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)


print("Done!")

Processing actinic-comedones
Average similarity for actinic-comedones: 0.013586750254034996
Processing athlete's-foot


  df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)


Average similarity for athlete's-foot: 0.028204061090946198
Processing basal-cell-carcinoma
Average similarity for basal-cell-carcinoma: -0.03374653682112694
Processing cellulitis
Average similarity for cellulitis: 0.0007615219801664352
Processing chickenpox
Average similarity for chickenpox: 0.02683234214782715
Processing cutaneous-larva-migrans
Average similarity for cutaneous-larva-migrans: 0.018109893426299095
Processing erythema-ab-igne
Average similarity for erythema-ab-igne: 0.02143237181007862
Processing herpes
Average similarity for herpes: -0.025620965287089348
Processing hidrocystoma
Average similarity for hidrocystoma: -0.017715953290462494
Processing impetigo
Average similarity for impetigo: 0.006046549882739782
Processing melanotic-macule
Average similarity for melanotic-macule: -0.0026231997180730104
Processing nail-fungus
Average similarity for nail-fungus: 0.028409769758582115
Processing perleche
Average similarity for perleche: -0.030819127336144447
Processing ringwor

In [49]:
df.sort_values(by="Average Similarity", ascending=False, inplace=True)
print(df)

                    Disease  Average Similarity
13                 ringworm            0.032346
11              nail-fungus            0.028410
1            athlete's-foot            0.028204
4                chickenpox            0.026832
16           spider-angioma            0.025637
6          erythema-ab-igne            0.021432
17           sycosis-barbae            0.020058
5   cutaneous-larva-migrans            0.018110
18              tinea-beard            0.016494
0         actinic-comedones            0.013587
9                  impetigo            0.006047
19              venous-lake            0.003821
3                cellulitis            0.000762
10         melanotic-macule           -0.002623
14                  rosacea           -0.010166
8              hidrocystoma           -0.017716
15                skin-tags           -0.023232
7                    herpes           -0.025621
12                 perleche           -0.030819
2      basal-cell-carcinoma           -0

In [50]:
csv_path = ""
if mode == 'trained':
    csv_path = "../results/average-similarity-softmax-align.csv"
else:
    csv_path = "../results/average-similarity-no-training-align.csv"

df.to_csv(csv_path, index=False)

In [51]:
# A test to see how many images are classified correctly
disease_descriptions = get_disease_skin_symptoms()
image_directory = get_train_image_directory()

labels = list(disease_descriptions.keys())
descriptions = list(disease_descriptions.values())
accuracies = {}

for subdir in os.listdir(image_directory):
    correct = 0
    total = 0

    print(f"Processing {subdir}")

    for image in os.listdir(os.path.join(image_directory, subdir)):
        image_path = os.path.join(image_directory, subdir, image)
        image = Image.open(image_path)
        inputs = processor(text=descriptions, images=image, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        idx = torch.argmax(probs)

        if labels[idx] == subdir:
            correct += 1
        total += 1

    print(f"Accuracy: {(correct * 100) / total} %")
    accuracies[subdir] = (correct * 100) / total

Processing actinic-comedones




Accuracy: 3.5 %
Processing athlete's-foot
Accuracy: 9.5 %
Processing basal-cell-carcinoma
Accuracy: 0.0 %
Processing cellulitis
Accuracy: 1.0 %
Processing chickenpox
Accuracy: 3.5 %
Processing cutaneous-larva-migrans
Accuracy: 3.5 %
Processing erythema-ab-igne
Accuracy: 7.0 %
Processing herpes
Accuracy: 0.0 %
Processing hidrocystoma
Accuracy: 0.0 %
Processing impetigo
Accuracy: 2.5 %
Processing melanotic-macule
Accuracy: 5.5 %
Processing nail-fungus
Accuracy: 10.0 %
Processing perleche
Accuracy: 0.0 %
Processing ringworm
Accuracy: 6.5 %
Processing rosacea
Accuracy: 15.789473684210526 %
Processing skin-tags
Accuracy: 0.0 %
Processing spider-angioma
Accuracy: 7.5 %
Processing sycosis-barbae
Accuracy: 6.0 %
Processing tinea-beard
Accuracy: 4.0 %
Processing venous-lake
Accuracy: 4.5 %


In [53]:
# Sorting and saving accuracies
save_path = ""
if mode == 'trained':
    save_path = "../results/accuracies-trained-softmax-align.json"
else:
    save_path = "../results/accuracies-no-training-align.json"
    
accuracies = {k: v for k, v in sorted(accuracies.items(), key=lambda item: item[1], reverse=True)}
with open(save_path, "w") as f:
    f.write(json.dumps(accuracies, indent=4))