In [None]:
!pip install torch_ema

Collecting torch_ema
  Downloading torch_ema-0.3-py3-none-any.whl.metadata (415 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->torch_ema)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->torch_ema)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->torch_ema)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->torch_ema)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->torch_ema)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->torch_ema)
  Downloading nvidia_cufft_cu12-11

In [2]:
import os
import random
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from collections import defaultdict, Counter
from dataclasses import dataclass
import pandas as pd
from tqdm import tqdm

import warnings

warnings.filterwarnings("ignore")

In [3]:
base_dir = Path("/kaggle/input/")
ham_dir = base_dir / "ham10000" / "ISIC-images"
bcn_dir = base_dir / "bcn20000"

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.environ["PYTHONHASHSEED"] = "42"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
@dataclass(frozen=True)
class HAMImage:
    path: Path
    identifier: str
    age: str
    sex: str
    diagnosis: str
    anatom_site: str

SEX_MAPPING = {
    "male": 0,
    "female": 1
}

DIAGNOSIS_MAPPING = {
    "Nevus": 0,
    "Melanoma, NOS": 1,
    "Pigmented benign keratosis": 2,
    "Dermatofibroma": 3,
    "Squamous cell carcinoma, NOS": 4,
    "Basal cell carcinoma": 5,
    "Solar or actinic keratosis": 6,
}

ANATOM_SITE_MAPPING = {
    "anterior torso": 0,
    "posterior torso": 1,
    "head/neck": 2,
    "upper extremity": 3,
    "lower extremity": 4,
    "palms/soles": 5,
    "oral/genital": 6
}

In [6]:
def concat_metadata(paths: List[Path]) -> pd.DataFrame:
    data = pd.DataFrame()
    for path in paths:
        metadata = pd.read_csv(path / "metadata.csv")
        metadata["base_path"] = path.as_posix()
        data = pd.concat([data, metadata])
    data = data.drop_duplicates(["isic_id"])
    return data

In [None]:
def load_metadata(metadata: pd.DataFrame) -> List[HAMImage]:
    metadata = metadata[pd.notnull(metadata["age_approx"])]
    metadata = metadata[pd.notnull(metadata["sex"])]
    metadata = metadata[pd.notnull(metadata["anatom_site_general"])]
    metadata = metadata[pd.notnull(metadata["diagnosis_3"])]
    images: List[HAMImage] = []
    for idx, row in metadata.iterrows():
        images.append(
            HAMImage(
                path=Path(row["base_path"]),
                identifier=row["isic_id"],
                age=row["age_approx"],
                sex=row["sex"],
                diagnosis=row["diagnosis_3"],
                anatom_site=row["anatom_site_general"]
            )
        )
    return images

In [None]:
def print_diagnosis_counts_by_sex(images: List[HAMImage]) -> None:
    counts_by_sex: dict[int, Counter[int]] = defaultdict(Counter)
    for img in images:
        counts_by_sex[img.sex][img.diagnosis] += 1

    all_diagnoses = sorted({d for ctr in counts_by_sex.values() for d in ctr})
    all_sexes = sorted(counts_by_sex)

    diag_col = "Diagnosis"
    sex_cols = [f"Sex {s}" for s in all_sexes]

    w_diag = max(len(diag_col), *(len(str(d)) for d in all_diagnoses))
    w_sex = {
        s: max(len(f"Sex {s}"), *(len(str(counts_by_sex[s][d])) for d in all_diagnoses))
        for s in all_sexes
    }

    header = f"{diag_col:<{w_diag}} " + " ".join(
        f"| {name:>{w_sex[s]}}" for s, name in zip(all_sexes, sex_cols)
    )
    sep = "-" * len(header)
    print(header)
    print(sep)

    for d in all_diagnoses:
        row = f"{str(d):<{w_diag}} " + " ".join(
            f"| {counts_by_sex[s][d]:>{w_sex[s]}}" for s in all_sexes
        )
        print(row)

In [9]:
def compute_mean_std(values: List[Any]) -> Tuple[float, float]:
    arr = np.array(values, dtype=np.float32)
    return float(arr.mean()), float(arr.std())

def normalize_meta(values: Sequence[float], means: Sequence[float], stds: Sequence[float]) -> torch.Tensor:
    normed = [(v - m) / s for v, m, s in zip(values, means, stds)]
    return torch.tensor(normed, dtype=torch.float32)

In [None]:
class HAMDiagnosisDataset(Dataset):
    def __init__(self, images: List[HAMImage], train: bool) -> None:
        self.images = images
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

        ages = [img.age for img in images]
        self.mean_age, self.std_age = compute_mean_std(ages)
        
        sexes = [SEX_MAPPING[img.sex] for img in images]
        self.mean_sex, self.std_sex = compute_mean_std(sexes)

        sites = [ANATOM_SITE_MAPPING[img.anatom_site] for img in images]
        self.mean_site, self.std_site = compute_mean_std(sites)

        if not train:
            self.transforms = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((500, 500)),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ])
        else:
            self.transforms = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((500, 500)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std)
            ])

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        image = self.images[index]
        image_path = image.path / f"{image.identifier}.jpg"
        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transforms(img)

        sex = SEX_MAPPING[image.sex]
        diagnosis = DIAGNOSIS_MAPPING[image.diagnosis]
        site = ANATOM_SITE_MAPPING[image.anatom_site]
        age = float(image.age)
    
        meta = normalize_meta(
            values=[sex, age, site],
            means=[self.mean_sex, self.mean_age, self.mean_site],
            stds=[self.std_sex, self.std_age, self.std_site]
        )
        return img, meta, diagnosis

In [87]:
metadata = concat_metadata(paths=[ham_dir])
images = load_metadata(metadata=metadata)
print_diagnosis_counts_by_sex(images=images)

Diagnosis                    | Sex female | Sex male
----------------------------------------------------
Basal cell carcinoma         |        209 |      362
Dermatofibroma               |         63 |       66
Melanoma, NOS                |        475 |      741
Nevus                        |       2929 |     3021
Pigmented benign keratosis   |        486 |      655
Solar or actinic keratosis   |         46 |       99
Squamous cell carcinoma, NOS |         79 |      150


In [88]:
from sklearn.model_selection import train_test_split
from torch.cuda.amp import autocast, GradScaler
from torch_ema import ExponentialMovingAverage

labels = [DIAGNOSIS_MAPPING[img.diagnosis] for img in images]

train_images, temp_images, train_labels, temp_labels = train_test_split(
    images, labels, test_size=0.2, stratify=labels, random_state=SEED
)

val_images, test_images, _, _ = train_test_split(
    temp_images, temp_labels, test_size=0.5, stratify=temp_labels, random_state=SEED
)

In [None]:
train = HAMDiagnosisDataset(train_images, train=True)
test = HAMDiagnosisDataset(test_images, train=False)
val = HAMDiagnosisDataset(val_images, train=False)

g = torch.Generator()
g.manual_seed(SEED)

trainloader = DataLoader(train, shuffle=True, batch_size=64, generator=g, num_workers=6)
testloader = DataLoader(test, shuffle=False, batch_size=32, num_workers=2)
valloader = DataLoader(val, shuffle=False, batch_size=32, num_workers=2)

In [90]:
from torchvision.models import ResNet50_Weights
resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)

In [91]:
for name, param in resnet.named_parameters():
    param.requires_grad = False

for name, module in resnet.named_modules():
    if isinstance(module, torch.nn.BatchNorm2d):
        module.eval()

In [None]:
num_features = resnet.fc.in_features
resnet.fc = torch.nn.Identity()

meta_net = torch.nn.Sequential(
    torch.nn.Linear(3, 16),
    torch.nn.SiLU(),
    torch.nn.Linear(16, 8),
    torch.nn.SiLU()
)

classifier = torch.nn.Sequential(
    torch.nn.Linear(num_features + 8, 1024),
    torch.nn.SiLU(),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(1024, 7)
)

class HAMNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.resnet = resnet
        self.meta_net = meta_net
        self.classifier = classifier

    def forward(self, img: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
        img_features = self.resnet(img)
        meta_features = self.meta_net(meta)
        x = torch.cat([img_features, meta_features], dim=1)
        return self.classifier(x)

In [93]:
from torchvision.models.resnet import ResNet
from torch.nn import CrossEntropyLoss

def evaluate(
    model: ResNet,
    loader: DataLoader,
    criterion: CrossEntropyLoss,
    ema: ExponentialMovingAverage,
) -> Tuple[int, int, List[np.ndarray], List[np.ndarray]]:
    model.eval()
    ema.store()
    ema.copy_to()
    correct = total = 0
    running_loss = 0.0

    all_preds, all_labels = [], []

    with torch.no_grad():
        for (imgs, meta, labels) in tqdm(loader, desc=f"Evaluation: "):
            imgs, meta, labels = imgs.to(device), meta.to(device), labels.to(device)

            logits = model(imgs, meta)
            loss = criterion(logits, labels)
            preds = logits.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            running_loss += loss.item() * imgs.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    ema.restore()
    eval_loss = running_loss / len(loader.dataset)
    return eval_loss, correct / total, all_preds, all_labels

In [94]:
from collections import Counter

label_counts = Counter(train_labels)
num_classes = len(list(set(train_labels)))

counts = torch.bincount(torch.tensor(train_labels), minlength=num_classes)
weights = counts.sum() / (num_classes * counts.clamp_min(1))

In [95]:
def unfreeze_layer(model: ResNet, layer: str) -> None:
    for name, param in model.named_parameters():
        if name.startswith(layer):
            param.requires_grad = True

    for name, module in model.named_modules():
        if name.startswith(layer) and isinstance(module, torch.nn.BatchNorm2d):
            module.train()

In [96]:
from collections import deque
from dataclasses import dataclass

from torch.optim import SGD

@dataclass(frozen=True)
class ParamGroup:
    layer: str
    epoch: int
    params: torch.nn.Parameter
    lr: float
    momentum: float
    decay: float

    @property
    def group(self) -> Dict[str, Any]:
        return {
            "params": self.params,
            "lr": self.lr,
            "momentum": self.momentum,
            "weight_decay": self.decay
        }


class ProgressiveUnfreezer:
    def __init__(self, model: ResNet, optimizer: SGD, params: List[ParamGroup]) -> None:
        self.model = model
        self.optimizer = optimizer
        self.params = deque(sorted(params, key=lambda x: x.epoch))
    
    def unfreeze(self, epoch: int) -> None:
        if not len(self.params):
            return None

        top = self.params[0]
        if epoch == top.epoch:
            unfreeze_layer(model=self.model, layer=top.layer)
            self.optimizer.add_param_group(top.group)
            self.params.popleft()
            print("Unfreezing layer...")

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

model = HAMNet()
model.to(device)

ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

EPOCHS = 100

unfreeze_layer(model=model.resnet, layer="fc")
optimizer = SGD([{"params": model.classifier.parameters(), "lr": 1e-3, "momentum": 0.9, "weight_decay": 1e-3}])
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)
criterion = CrossEntropyLoss(label_smoothing=0.05, weight=torch.tensor(weights).to(device))

patience = 7
best_loss = float("inf")
no_improve = 0

scaler = GradScaler()
unfreezer = ProgressiveUnfreezer(
    model.resnet, 
    optimizer,
    [
        ParamGroup(
            layer="layer4", 
            epoch=3, 
            params=model.resnet.layer4.parameters(), 
            lr=5e-3,
            momentum=0.9,
            decay=1e-5
        ),
        ParamGroup(
            layer="layer3",
            epoch=6,
            params=model.resnet.layer3.parameters(),
            lr=3e-3,
            momentum=0.9,
            decay=1e-5
        ),
    ]
)

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    for imgs, meta, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, meta, labels = imgs.to(device), meta.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast():
            preds = model(imgs, meta)
            loss = criterion(preds, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        ema.update()

        running_loss += loss.item() * imgs.size(0)

    train_loss = running_loss / len(trainloader.dataset)

    scheduler.step()
    val_loss, val_acc, _, _ = evaluate(model, valloader, criterion, ema)

    print(f"Epoch {epoch+1}/{EPOCHS}.. "
          f"Train loss: {train_loss:.3f}.. "
          f"Val loss: {val_loss:.3f}.. "
          f"Accuracy: {val_acc:.3f}..")

    if val_loss < best_loss:
        no_improve = 0
        best_loss = val_loss
    else:
        no_improve += 1
        if no_improve >= patience:
            break

    unfreezer.unfreeze(epoch=epoch+1)

Epoch 1/100: 100%|██████████| 118/118 [01:37<00:00,  1.20it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 1/100.. Train loss: 2.280.. Val loss: 2.283.. Accuracy: 0.017..


Epoch 2/100: 100%|██████████| 118/118 [01:37<00:00,  1.22it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 2/100.. Train loss: 2.276.. Val loss: 2.271.. Accuracy: 0.017..


Epoch 3/100: 100%|██████████| 118/118 [01:37<00:00,  1.22it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 3/100.. Train loss: 2.247.. Val loss: 2.259.. Accuracy: 0.021..
Unfreezing layer...


Epoch 4/100: 100%|██████████| 118/118 [01:41<00:00,  1.16it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 4/100.. Train loss: 2.207.. Val loss: 2.221.. Accuracy: 0.030..


Epoch 5/100: 100%|██████████| 118/118 [01:42<00:00,  1.15it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 5/100.. Train loss: 2.103.. Val loss: 2.149.. Accuracy: 0.110..


Epoch 6/100: 100%|██████████| 118/118 [01:42<00:00,  1.16it/s]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 6/100.. Train loss: 2.016.. Val loss: 2.052.. Accuracy: 0.376..
Unfreezing layer...


Epoch 7/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 7/100.. Train loss: 1.867.. Val loss: 1.916.. Accuracy: 0.515..


Epoch 8/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 8/100.. Train loss: 1.688.. Val loss: 1.751.. Accuracy: 0.604..


Epoch 9/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 9/100.. Train loss: 1.533.. Val loss: 1.633.. Accuracy: 0.620..


Epoch 10/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 10/100.. Train loss: 1.446.. Val loss: 1.588.. Accuracy: 0.659..


Epoch 11/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:17<00:00,  1.67it/s]


Epoch 11/100.. Train loss: 1.369.. Val loss: 1.518.. Accuracy: 0.680..


Epoch 12/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 12/100.. Train loss: 1.294.. Val loss: 1.507.. Accuracy: 0.678..


Epoch 13/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 13/100.. Train loss: 1.233.. Val loss: 1.571.. Accuracy: 0.682..


Epoch 14/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 14/100.. Train loss: 1.220.. Val loss: 1.452.. Accuracy: 0.694..


Epoch 15/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 15/100.. Train loss: 1.141.. Val loss: 1.491.. Accuracy: 0.701..


Epoch 16/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 16/100.. Train loss: 1.118.. Val loss: 1.439.. Accuracy: 0.717..


Epoch 17/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 17/100.. Train loss: 1.053.. Val loss: 1.377.. Accuracy: 0.759..


Epoch 18/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 18/100.. Train loss: 1.024.. Val loss: 1.393.. Accuracy: 0.773..


Epoch 19/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 19/100.. Train loss: 1.012.. Val loss: 1.394.. Accuracy: 0.775..


Epoch 20/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 20/100.. Train loss: 0.997.. Val loss: 1.389.. Accuracy: 0.773..


Epoch 21/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:17<00:00,  1.67it/s]


Epoch 21/100.. Train loss: 0.987.. Val loss: 1.399.. Accuracy: 0.795..


Epoch 22/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:17<00:00,  1.67it/s]


Epoch 22/100.. Train loss: 0.945.. Val loss: 1.404.. Accuracy: 0.758..


Epoch 23/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 23/100.. Train loss: 0.930.. Val loss: 1.411.. Accuracy: 0.788..


Epoch 24/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 24/100.. Train loss: 0.912.. Val loss: 1.366.. Accuracy: 0.774..


Epoch 25/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 25/100.. Train loss: 0.905.. Val loss: 1.375.. Accuracy: 0.812..


Epoch 26/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 26/100.. Train loss: 0.895.. Val loss: 1.363.. Accuracy: 0.827..


Epoch 27/100: 100%|██████████| 118/118 [02:05<00:00,  1.07s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 27/100.. Train loss: 0.869.. Val loss: 1.366.. Accuracy: 0.803..


Epoch 28/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 28/100.. Train loss: 0.882.. Val loss: 1.334.. Accuracy: 0.834..


Epoch 29/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 29/100.. Train loss: 0.862.. Val loss: 1.345.. Accuracy: 0.814..


Epoch 30/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 30/100.. Train loss: 0.868.. Val loss: 1.338.. Accuracy: 0.823..


Epoch 31/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:17<00:00,  1.67it/s]


Epoch 31/100.. Train loss: 0.864.. Val loss: 1.321.. Accuracy: 0.849..


Epoch 32/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 32/100.. Train loss: 0.856.. Val loss: 1.346.. Accuracy: 0.832..


Epoch 33/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 33/100.. Train loss: 0.841.. Val loss: 1.341.. Accuracy: 0.853..


Epoch 34/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 34/100.. Train loss: 0.839.. Val loss: 1.353.. Accuracy: 0.850..


Epoch 35/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 35/100.. Train loss: 0.839.. Val loss: 1.323.. Accuracy: 0.845..


Epoch 36/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.67it/s]


Epoch 36/100.. Train loss: 0.835.. Val loss: 1.337.. Accuracy: 0.852..


Epoch 37/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 37/100.. Train loss: 0.839.. Val loss: 1.357.. Accuracy: 0.846..


Epoch 38/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 38/100.. Train loss: 0.833.. Val loss: 1.319.. Accuracy: 0.841..


Epoch 39/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 39/100.. Train loss: 0.818.. Val loss: 1.352.. Accuracy: 0.859..


Epoch 40/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 40/100.. Train loss: 0.824.. Val loss: 1.321.. Accuracy: 0.855..


Epoch 41/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.64it/s]


Epoch 41/100.. Train loss: 0.824.. Val loss: 1.340.. Accuracy: 0.852..


Epoch 42/100: 100%|██████████| 118/118 [02:04<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 42/100.. Train loss: 0.812.. Val loss: 1.343.. Accuracy: 0.862..


Epoch 43/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]


Epoch 43/100.. Train loss: 0.818.. Val loss: 1.328.. Accuracy: 0.856..


Epoch 44/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.65it/s]


Epoch 44/100.. Train loss: 0.812.. Val loss: 1.357.. Accuracy: 0.867..


Epoch 45/100: 100%|██████████| 118/118 [02:05<00:00,  1.06s/it]
Evaluation: 100%|██████████| 30/30 [00:18<00:00,  1.66it/s]

Epoch 45/100.. Train loss: 0.812.. Val loss: 1.356.. Accuracy: 0.866..





In [98]:
test_loss, test_acc, _, _ = evaluate(model, testloader, criterion, ema)
print(f"Test Loss: {test_loss:.3f}.. "
      f"Test Accuracy: {test_acc:.3f}..")

Evaluation: 100%|██████████| 30/30 [00:17<00:00,  1.67it/s]

Test Loss: 1.317.. Test Accuracy: 0.853..





In [None]:
torch.save(model.state_dict(), "ham-net.pth")