In [18]:
import os
import random
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from collections import defaultdict, Counter
from dataclasses import dataclass
import pandas as pd
from tqdm import tqdm

import warnings

# NOTE(review): this silences every warning category for the rest of the
# session, including deprecation and DataLoader notices that can be useful
# when debugging. Consider narrowing it to specific categories or modules.
warnings.filterwarnings("ignore")

In [None]:
# Colab-only setup: mount Google Drive into the runtime file system so the
# image folder is reachable through a regular path.
from google.colab import drive

drive.mount('/content/drive')
# Folder that the rest of the notebook reads "<isic_id>.jpg" files and
# "metadata.csv" from.
data_dir = '/content/drive/MyDrive/ISIC-images/'

Mounted at /content/drive


In [3]:
# Single source of truth for every random number generator used below.
SEED = 42

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Environment knobs for reproducibility. CUBLAS_WORKSPACE_CONFIG is required by
# cuBLAS when deterministic algorithms are enforced. PYTHONHASHSEED is read at
# interpreter start-up, so setting it here only reaches child processes.
os.environ.update({
  "PYTHONHASHSEED": str(SEED),
  "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
})

# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
  _seed_fn(SEED)

# Prefer reproducible kernels over the fastest ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

In [4]:
@dataclass(frozen=True)
class HAM10000Image:
  """Immutable metadata record for one image, as read from the metadata CSV.

  `sex` and `diagnosis` hold the raw CSV strings; they are converted to
  integer codes via SEX_MAPPING / DIAGNOSIS_MAPPING only when a sample is
  built by the dataset.
  """
  # Value of the "isic_id" column; also the image file stem ("<identifier>.jpg").
  identifier: str
  # Value of the "age_approx" column (pandas yields a float for this column).
  age: float
  # Value of the "sex" column, e.g. "male" / "female".
  sex: str
  # Value of the "diagnosis_1" column, e.g. "Malignant" / "Benign".
  diagnosis: str

# Integer code assigned to each value of the metadata "sex" column.
SEX_MAPPING = dict(male=0, female=1)

# Class index assigned to each diagnosis label; the position in the tuple is
# the index, so "Malignant" is class 0 and "Benign" is class 1.
DIAGNOSIS_MAPPING = {
  label: index for index, label in enumerate(("Malignant", "Benign"))
}

In [5]:
def load_metadata(path: Path) -> List["HAM10000Image"]:
  """Read the metadata CSV and return one record per usable image.

  A row is kept only when both "age_approx" and "sex" are present and
  "diagnosis_1" is either "Malignant" or "Benign".

  Args:
    path: Location of the metadata CSV file.

  Returns:
    The retained rows as HAM10000Image records, in file order.
  """
  data = pd.read_csv(path)
  data = data.dropna(subset=["age_approx", "sex"])
  data = data[data["diagnosis_1"].isin(["Malignant", "Benign"])]

  # Iterate the four needed columns in lockstep instead of materialising a
  # Series per row with iterrows(), which is much slower on large tables.
  columns = (data["isic_id"], data["age_approx"], data["sex"], data["diagnosis_1"])
  return [
    HAM10000Image(identifier=isic_id, age=age, sex=sex, diagnosis=diagnosis)
    for isic_id, age, sex, diagnosis in zip(*columns)
  ]

In [6]:
def print_diagnosis_counts_by_sex(images: List["HAM10000Image"]) -> None:
    """Print a text table of image counts per diagnosis (rows) and sex (columns).

    Column widths adapt to the longest label or count. With an empty `images`
    list only the "Diagnosis" header and its separator line are printed.

    Args:
        images: Records exposing `.sex` and `.diagnosis` attributes.
    """
    counts_by_sex: dict[str, Counter[str]] = defaultdict(Counter)
    for img in images:
        counts_by_sex[img.sex][img.diagnosis] += 1

    all_diagnoses = sorted({d for ctr in counts_by_sex.values() for d in ctr})
    all_sexes = sorted(counts_by_sex)

    diag_col = "Diagnosis"
    sex_cols = {s: f"Sex {s}" for s in all_sexes}

    # max() is fed a single list: the star-args form degenerates to
    # max(<int>) and raises TypeError when there are no diagnoses at all.
    w_diag = max([len(diag_col), *(len(str(d)) for d in all_diagnoses)])
    w_sex = {
        s: max([len(sex_cols[s]), *(len(str(counts_by_sex[s][d])) for d in all_diagnoses)])
        for s in all_sexes
    }

    header = f"{diag_col:<{w_diag}} " + " ".join(
        f"| {sex_cols[s]:>{w_sex[s]}}" for s in all_sexes
    )
    print(header)
    print("-" * len(header))

    for d in all_diagnoses:
        # Counter returns 0 for a diagnosis never seen for this sex.
        row = f"{str(d):<{w_diag}} " + " ".join(
            f"| {counts_by_sex[s][d]:>{w_sex[s]}}" for s in all_sexes
        )
        print(row)

In [53]:
class HAM10000DiagnosisDataset(Dataset):
  """Dataset yielding (image tensor, metadata tensor, diagnosis index) samples.

  Images are read from "<data_dir>/<identifier>.jpg". The metadata tensor is
  [sex, age], with sex mapped to -1/+1 and age standardised with
  `mean_age` / `std_age`. Those two attributes are computed from the records
  passed to the constructor and may be overwritten afterwards, e.g. to reuse
  the training-split statistics on held-out splits.
  """

  def __init__(self, images: List[HAM10000Image], train: bool) -> None:
    """Store the records and build the image pipeline.

    Args:
      images: Metadata records backing the samples.
      train: When True use random augmentation; otherwise a deterministic
        resize + center crop.
    """
    self.images = images
    # Per-channel statistics used to normalise the image tensor.
    self.mean = [0.485, 0.456, 0.406]
    self.std = [0.229, 0.224, 0.225]

    ages = np.array([img.age for img in images], dtype=np.float64)
    if ages.size:
      self.mean_age, self.std_age = ages.mean(), ages.std()
    else:
      # No records: neutral statistics instead of NaN from an empty mean.
      self.mean_age, self.std_age = 0.0, 1.0
    if not self.std_age > 0:
      # Constant (or NaN) ages would turn the standardised age into NaN/inf.
      self.std_age = 1.0

    if not train:
      self.transforms = transforms.Compose([
        transforms.ToPILImage(),
        # Shrink first so the 224 crop covers most of the picture; cropping at
        # native resolution keeps only a small central patch of large images
        # and does not match the rescaled crops seen during training.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std),
      ])
    else:
      self.transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(self.mean, self.std)
      ])

  def __len__(self) -> int:
    """Return the number of records."""
    return len(self.images)

  def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Load and preprocess the sample at `index`.

    Returns:
      (img, meta, diagnosis): the transformed image tensor, a float32 tensor
      [sex, age], and the integer class index from DIAGNOSIS_MAPPING.

    Raises:
      FileNotFoundError: If the image file cannot be read.
      KeyError: If the record's sex or diagnosis is not in the mappings.
    """
    image = self.images[index]
    image_path = Path(data_dir) / f"{image.identifier}.jpg"
    # Pass a plain string: not every OpenCV build accepts path-like objects.
    img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
    if img is None:
      # cv2.imread signals failure by returning None rather than raising.
      raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = self.transforms(img)

    sex = SEX_MAPPING[image.sex]
    diagnosis = DIAGNOSIS_MAPPING[image.diagnosis]
    age = int(image.age)

    # Map the 0/1 sex code to -1/+1 and standardise the age.
    sex = (sex - 0.5) * 2
    age = (age - self.mean_age) / self.std_age
    meta = torch.tensor([sex, age], dtype=torch.float32)
    return img, meta, diagnosis

In [None]:
# Parse the metadata CSV stored alongside the images, then print how the
# diagnoses are distributed across sexes.
metadata_path = Path(data_dir) / "metadata.csv"
images = load_metadata(metadata_path)
print_diagnosis_counts_by_sex(images)

In [63]:
# Fractions of the data used for training and testing; validation receives
# whatever remains after both slices are taken.
TRAIN_SPLIT = 0.8
TEST_SPLIT = 0.1

# Shuffle a copy with a dedicated, seeded generator. `images` is left
# untouched and the global RNG is not consumed, so re-running this cell always
# produces the same split.
shuffled_images = list(images)
random.Random(SEED).shuffle(shuffled_images)

num_images = len(shuffled_images)

train_index = int(num_images * TRAIN_SPLIT)
test_index = train_index + int(num_images * TEST_SPLIT)

train_images = shuffled_images[:train_index]
test_images = shuffled_images[train_index:test_index]
val_images = shuffled_images[test_index:]

# The three slices must partition the data exactly.
assert len(train_images) + len(test_images) + len(val_images) == num_images

In [64]:
train = HAM10000DiagnosisDataset(train_images, train=True)
test = HAM10000DiagnosisDataset(test_images, train=False)
val = HAM10000DiagnosisDataset(val_images, train=False)

# Standardise age on the held-out splits with the training statistics, so all
# three splits share one feature scale and nothing is fitted on val/test data.
for held_out in (test, val):
  held_out.mean_age, held_out.std_age = train.mean_age, train.std_age

# Reshuffle the training samples every epoch; the dedicated seeded generator
# keeps the order reproducible across runs.
loader_rng = torch.Generator()
loader_rng.manual_seed(SEED)

trainloader = DataLoader(
  train, shuffle=True, batch_size=64, num_workers=6, generator=loader_rng
)
testloader = DataLoader(test, shuffle=False, batch_size=32, num_workers=2)
valloader = DataLoader(val, shuffle=False, batch_size=32, num_workers=2)

In [73]:
from torchvision.models import ResNet50_Weights
# ResNet-50 backbone initialised from torchvision's default pretrained
# weights; its final layer is replaced further down in the notebook.
resnet50 = models.resnet50(weights=ResNet50_Weights.DEFAULT)

In [74]:
# Fine-tune only the last residual stage: parameters whose name starts with
# "layer4" stay trainable, every other backbone parameter is frozen.
for name, param in resnet50.named_parameters():
  param.requires_grad = name.startswith("layer4")

In [77]:
# Put every BatchNorm2d layer outside "layer4" into eval mode so it uses its
# stored running statistics.
# NOTE(review): calling .train() on the backbone (or on a module containing
# it) switches all submodules back to training mode, so this setting does not
# survive a later .train() call unless it is re-applied afterwards.
for name, module in resnet50.named_modules():
  is_frozen_batchnorm = (
    isinstance(module, torch.nn.BatchNorm2d) and not name.startswith("layer4")
  )
  if is_frozen_batchnorm:
    module.eval()

In [67]:
# Strip the backbone's classification layer so it outputs pooled features.
# The guard keeps this cell re-runnable: once `fc` has been swapped for
# Identity it no longer has `in_features`, and `num_features` from the first
# run is reused.
if isinstance(resnet50.fc, torch.nn.Linear):
  num_features = resnet50.fc.in_features
  resnet50.fc = torch.nn.Identity()

# Width of the metadata embedding; shared by the two networks below so the
# concatenated feature size always matches.
META_FEATURES = 8

# Small MLP that embeds the 2-value metadata vector.
meta_net = torch.nn.Sequential(
  torch.nn.Linear(2, 16),
  torch.nn.ReLU(),
  torch.nn.Linear(16, META_FEATURES),
  torch.nn.ReLU()
)

# Classification head over the concatenated image and metadata features; it
# produces two logits.
classifier = torch.nn.Sequential(
  torch.nn.Linear(num_features + META_FEATURES, 1024),
  torch.nn.ReLU(),
  torch.nn.Dropout(0.1),
  torch.nn.Linear(1024, 128),
  torch.nn.ReLU(),
  torch.nn.Dropout(0.1),
  torch.nn.Linear(128, 2)
)

class HAMNet(torch.nn.Module):
  """Two-branch classifier combining image features with patient metadata.

  The image goes through `resnet`, the metadata vector through `meta_net`;
  both outputs are concatenated along dim 1 and fed to `classifier`.
  The three sub-networks are the module-level `resnet50`, `meta_net` and
  `classifier` objects, so every instance shares the same weights.
  """

  def __init__(self) -> None:
    super().__init__()
    self.resnet = resnet50
    self.meta_net = meta_net
    self.classifier = classifier

  def train(self, mode: bool = True) -> "HAMNet":
    """Set the training mode while keeping frozen BatchNorm layers in eval.

    Module.train() switches every submodule to training mode, which would let
    BatchNorm layers with frozen (non-trainable) parameters keep updating
    their running statistics. After delegating to the default behaviour,
    those layers are put back into eval mode.

    Args:
      mode: True for training mode, False for evaluation mode.

    Returns:
      self, as Module.train() does.
    """
    super().train(mode)
    if mode:
      for module in self.resnet.modules():
        is_batchnorm = isinstance(module, torch.nn.BatchNorm2d)
        if is_batchnorm and not any(p.requires_grad for p in module.parameters()):
          module.eval()
    return self

  def forward(self, img: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
    """Return the classifier output for a batch of images and metadata."""
    img_features = self.resnet(img)
    meta_features = self.meta_net(meta)
    x = torch.cat([img_features, meta_features], dim=1)
    return self.classifier(x)

In [68]:
from torchvision.models.resnet import ResNet
from torch.nn import CrossEntropyLoss

def evaluate(
  model: torch.nn.Module,
  loader: DataLoader,
  criterion: CrossEntropyLoss
) -> Tuple[float, float]:
  """Compute the mean loss and accuracy of `model` over `loader`.

  The model is switched to eval mode and left in it; callers that continue
  training must call `model.train()` again.

  Args:
    model: Network called as `model(imgs, meta)` and returning class logits.
    loader: Yields `(imgs, meta, labels)` batches.
    criterion: Loss applied to `(logits, labels)`; assumed to return the
      batch mean, since it is re-weighted by the batch size here.

  Returns:
    `(loss, accuracy)` averaged over every sample seen.

  Raises:
    ValueError: If `loader` yields no samples.
  """
  model.eval()
  # Run on whatever device the model lives on rather than a notebook global.
  first_param = next(model.parameters(), None)
  eval_device = first_param.device if first_param is not None else torch.device("cpu")

  correct = total = 0
  running_loss = 0.0

  with torch.no_grad():
    for (imgs, meta, labels) in loader:
      imgs, meta, labels = imgs.to(eval_device), meta.to(eval_device), labels.to(eval_device)

      logits = model(imgs, meta)
      loss = criterion(logits, labels)
      preds = logits.argmax(dim=1)

      correct += (preds == labels).sum().item()
      total += labels.size(0)

      # Undo the per-batch mean so batches of different size weigh correctly.
      running_loss += loss.item() * imgs.size(0)

  if total == 0:
    raise ValueError("evaluate() received a loader that yields no samples")

  return running_loss / total, correct / total

In [None]:
# ---- Training configuration -------------------------------------------------
EPOCHS = 100
MAX_LR = 0.01  # peak learning rate of the one-cycle schedule

# Early stopping: give up after `patience` consecutive epochs without a new
# best validation loss.
patience, best_loss = 7, float('inf')
no_improve = 0
best_state = None  # copy of the weights that achieved `best_loss`

model = HAMNet()
model.to(device)

# Only parameters left trainable are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=MAX_LR,
    weight_decay=1e-4
)

criterion = CrossEntropyLoss()
# The scheduler is stepped once per batch, hence steps_per_epoch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=len(trainloader),
    epochs=EPOCHS,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=15.0,
    final_div_factor=1e3,
)

# cudnn.benchmark is deliberately left disabled: the auto-tuner may choose
# different kernels from run to run, which would defeat the deterministic
# configuration set up at the top of the notebook.

# ---- Training loop ----------------------------------------------------------
for epoch in range(EPOCHS):
  model.train()
  running_loss = 0.0

  for imgs, meta, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
    imgs, meta, labels = imgs.to(device), meta.to(device), labels.to(device)

    optimizer.zero_grad()
    logits = model(imgs, meta)
    loss = criterion(logits, labels)

    loss.backward()
    optimizer.step()
    scheduler.step()

    running_loss += loss.item() * imgs.size(0)

  train_loss = running_loss / len(trainloader.dataset)
  val_loss, val_acc = evaluate(model, valloader, criterion)

  # Report before the early-stopping check so the last epoch is logged too.
  print(f"Epoch {epoch+1}/{EPOCHS}.. "
        f"Train loss: {train_loss:.3f}.. "
        f"Val loss: {val_loss:.3f}.. "
        f"Accuracy: {val_acc:.3f}..")

  if val_loss < best_loss:
    best_loss = val_loss
    no_improve = 0
    # Snapshot on the CPU so the copy does not take up accelerator memory.
    best_state = {
      key: value.detach().cpu().clone()
      for key, value in model.state_dict().items()
    }
  else:
    no_improve += 1
    if no_improve >= patience:
      print(f"Early stopping after epoch {epoch+1}: "
            f"no improvement for {patience} epochs.")
      break

# Continue with the weights that scored the best validation loss, not the
# ones from the last (worse) epochs.
if best_state is not None:
  model.load_state_dict(best_state)

In [None]:
# Score the model on the held-out test split and report both metrics.
test_loss, test_acc = evaluate(model, testloader, criterion)
print(f"Test loss: {test_loss:.3f}.. Accuracy: {test_acc:.3f}..")