In [1]:
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2
import os
from collections import OrderedDict
from timm.models.layers.activations import *
import torch
import torch.nn as nn
import torch.nn.functional as F

image_size=224
num_classes=2


def get_training_augmentation():
    augmentations_train = A.Compose(
        [
            A.RandomResizedCrop(224, 224, scale=(0.8, 1.0)),
            A.Transpose(p=0.5),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightness(limit=0.2, p=0.75),
            A.RandomContrast(limit=0.2, p=0.75),
            
            A.OneOf([A.MotionBlur(blur_limit=5),
                     A.MedianBlur(blur_limit=5),
                     A.GaussianBlur(blur_limit=5),
                     A.GaussNoise(var_limit=(5.0, 30.0)),
                     ], p=0.7),

            A.OneOf([A.OpticalDistortion(distort_limit=1.0),
                     A.GridDistortion(num_steps=5, distort_limit=1.),
                     A.ElasticTransform(alpha=3),
                    ], p=0.7),

            A.CLAHE(clip_limit=4.0, p=0.7),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
            A.Resize(image_size, image_size),
            A.Cutout(max_h_size=int(image_size * 0.375), max_w_size=int(image_size * 0.375), num_holes=1, p=0.7),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ],
    )
    return lambda img: augmentations_train(image=np.array(img))


def get_test_augmentation():
    augmentations_val = A.Compose(
        [
            A.SmallestMaxSize(256),
            A.CenterCrop(224, 224),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ],
    )
    return lambda img: augmentations_val(image=np.array(img))

In [2]:
# Based on https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py
class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, alpha, T, criterion):
        super().__init__()
        self.criterion = criterion
        self.KLDivLoss = nn.KLDivLoss(reduction="batchmean")
        self.alpha = alpha
        self.T = T

    def forward(self, input, target, teacher_target):
        loss = self.KLDivLoss(
            F.log_softmax(input / self.T, dim=1),
            F.softmax(teacher_target / self.T, dim=1),
        ) * (self.alpha * self.T * self.T) + self.criterion(input, target) * (
            1.0 - self.alpha
        )
        return loss


class MixUpAugmentationLoss(nn.Module):
    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, input, target, *args):
        # Validation step
        if isinstance(target, torch.Tensor):
            return self.criterion(input, target, *args)
        target_a, target_b, lmbd = target
        return lmbd * self.criterion(input, target_a, *args) + (
            1 - lmbd
        ) * self.criterion(input, target_b, *args)


# Based on https://github.com/pytorch/pytorch/issues/7455
class LabelSmoothingLoss(nn.Module):
    def __init__(self, n_classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = n_classes
        self.dim = dim

    def forward(self, output, target, *args):
        output = output.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Create matrix with shapes batch_size x n_classes
            true_dist = torch.zeros_like(output)
            # Initialize all elements with epsilon / N - 1
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # Fill correct class for each sample in the batch with 1 - epsilon
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * output, dim=self.dim))

In [3]:
import warnings
from typing import Dict

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


class TB_CXR(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        # We need to specify a number of classes there to avoid the RuntimeError
        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3006
        # However, we will get another warning and it should be handled in forward steps
        self.metric = pl.metrics.Accuracy(num_classes=self.config.num_classes)
        #dim_feats = self.model.classifier.in_features  # =2048 .fc or .classifier
        nb_classes = self.config.num_classes
        fc = nn.Sequential(OrderedDict([#('fc1', nn.Linear(1280, 1000, bias=True)),
                                 ('fc1', nn.Linear(2048, 1000, bias=True)),
							     ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('dropout1', nn.Dropout(0.7)),
                                 ('fc2', nn.Linear(1000, 512)),
								 ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('swish1', Swish()),
								 ('dropout2', nn.Dropout(0.5)),
								 ('fc3', nn.Linear(512, 128)),
								 ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
							     ('swish2', Swish()),
								 ('fc4', nn.Linear(128, num_classes)),
								 ('output', nn.Softmax(dim=1))
							 ]))
        # connect base model (EfficientNet_B0) with modified classifier layer
        
        #dim_feats = self.teacher.classifier.in_features  # =2048 .fc cho cac model khac
        nb_classes = self.config.num_classes
        self.model.fc = fc
        #self.model.classifier = nn.Linear(dim_feats, nb_classes)

    def forward(self, x):
        return self.model(x)

    def setup(self, stage):
        if self.config.use_smoothing:
            self.criterion = LabelSmoothingLoss(
                self.config.num_classes, self.config.smoothing,
            )
        else:
            self.criterion = nn.CrossEntropyLoss()

        if self.config.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)

    def on_epoch_start(self):
        self.previous_batch = [None, None]

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        if self.args.use_mixup:
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y)
        else:
            logits = self(x)
            loss = self.criterion(logits, y)
        # We ignore a warning about a mismatch between a number of predicted classes
        # and a number of initialized for Accuracy class
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}
        self.previous_batch = [x, y]

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        val_loss = self.criterion(logits, y)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}

    def test_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            test_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"test_acc": test_accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_accuracy = torch.stack([x["val_acc"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss, "val_acc": avg_accuracy}
        return {
            "avg_val_loss": avg_loss,
            "avg_val_acc": avg_accuracy,
            "log": tensorboard_logs,
        }

    def test_epoch_end(self, outputs):
        avg_accuracy = torch.stack([x["test_acc"] for x in outputs]).mean()
        return {"avg_test_acc": avg_accuracy.item()}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.lr)
        if self.config.use_cosine_scheduler:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.config.max_epochs, eta_min=0.0,
            )
        else:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=self.config.milestones,
            )
        return [optimizer], [scheduler]

    def train_dataloader(self):
        train_dataset = ImageFolder(
            os.path.join('/home/linh/Downloads/TB', "train"),
            transform=get_training_augmentation(),
        )

        return DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        val_dataset = ImageFolder(
            os.path.join('/home/linh/Downloads/TB', "test"),
            transform=get_test_augmentation(),
        )
        return DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=self.config.workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def optimizer_step(self, epoch, batch_idx, optimizer, *args, **kwargs):
        # Learning Rate warm-up
        if self.config.warmup != -1 and epoch < self.config.warmup:
            lr = self.config.lr * (epoch + 1) / self.config.warmup
            for pg in optimizer.param_groups:
                pg["lr"] = lr

        self.logger.log_metrics({"lr": optimizer.param_groups[0]["lr"]}, step=epoch)
        optimizer.step()
        optimizer.zero_grad()

    def mixup_batch(self, x, y, x_previous, y_previous):
        lmbd = (
            np.random.beta(self.config.mixup_alpha, self.config.mixup_alpha)
            if self.config.mixup_alpha > 0
            else 1
        )
        if x_previous is None:
            x_previous = torch.empty_like(x).copy_(x)
            y_previous = torch.empty_like(y).copy_(y)
        batch_size = x.size(0)
        index = torch.randperm(batch_size)
        # If current batch size != previous batch size, we take only a part of the previous batch
        x_previous = x_previous[:batch_size, ...]
        y_previous = y_previous[:batch_size, ...]
        x_mixed = lmbd * x + (1 - lmbd) * x_previous[index, ...]
        y_a, y_b = y, y_previous[index]
        return x_mixed, y_a, y_b, lmbd


class TB_CXR_KD(TB_CXR):
    def __init__(self, model, teacher, config):
        super().__init__(model, config)
        self.teacher = teacher
        fc = nn.Sequential(OrderedDict([#('fc1', nn.Linear(1280, 1000, bias=True)),
                                 ('fc1', nn.Linear(2048, 1000, bias=True)),
							     ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('dropout1', nn.Dropout(0.7)),
                                 ('fc2', nn.Linear(1000, 512)),
								 ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('swish1', Swish()),
								 ('dropout2', nn.Dropout(0.5)),
								 ('fc3', nn.Linear(512, 128)),
								 ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
							     ('swish2', Swish()),
								 ('fc4', nn.Linear(128, num_classes)),
								 ('output', nn.Softmax(dim=1))
							 ]))
        # connect base model (EfficientNet_B0) with modified classifier layer
        
        #dim_feats = self.teacher.classifier.in_features  # =2048 .fc cho cac model khac
        nb_classes = self.config.num_classes
        self.teacher.fc = fc
        #self.teacher.classifier = nn.Linear(dim_feats, nb_classes)
        teacher_checkpoint = torch.load("/home/linh/Downloads/TB/weights/EfficientNet_B1_Mod.pth")
        self.teacher.load_state_dict(teacher_checkpoint["model_state_dict"])


    def setup(self, stage):
        criterion = (
            LabelSmoothingLoss(self.config.num_classes, self.config.smoothing)
            if self.config.use_smoothing
            else nn.CrossEntropyLoss()
        )
        self.criterion = KnowledgeDistillationLoss(
            self.config.distill_alpha, self.config.distill_temperature, 
            criterion=criterion,
        )
        if self.config.use_mixup:
            self.criterion = MixUpAugmentationLoss(self.criterion)
        self.teacher.eval()

    def training_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        with torch.no_grad():
            teacher_output = self.teacher(x)

        if self.config.use_mixup:
            mixup_x, *mixup_y = self.mixup_batch(x, y, *self.previous_batch)
            logits = self(mixup_x)
            loss = self.criterion(logits, mixup_y, teacher_output)
        else:
            logits = self(x)
            loss = self.criterion(logits, y, teacher_output)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            accuracy = self.metric(logits.argmax(dim=-1), y)
        tensorboard_logs = {"train_loss": loss, "train_acc": accuracy}

        return {"loss": loss, "progress_bar": tensorboard_logs, "log": tensorboard_logs}

    def validation_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with torch.no_grad():
            teacher_output = self.teacher(x)
        val_loss = self.criterion(logits, y, teacher_output)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            val_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"val_loss": val_loss, "val_acc": val_accuracy}

    def test_step(self, batch, *args):
        x, y = batch[0]["image"], batch[1]
        logits = self(x)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            test_accuracy = self.metric(logits.argmax(dim=-1), y)
        return {"test_acc": test_accuracy}



In [4]:
from dataclasses import dataclass

@dataclass
class Config:
    image_size: int = 224
    workers: int = 4 # Number of data loading workers
    use_smoothing: bool = True # Use label smoothing trick
    smoothing: float = 0.2 # Coefficient for label smoothing (from 0.0 (no smoothing) to 1.0)
    use_mixup: bool = True # Use mixup augmentation during training
    mixup_alpha: float = 0.2 # Alpha value for mixup augmentation
    use_cosine_scheduler: bool = True # Use Cosine LR Scheduler instead of MultiStep
    batch_size: int = 64 # Mini-batch size
    lr: float = 1e-4 # Initial learning rate
    milestones: tuple = (15, 30) # Milestones for dropping the LR
    warmup: int = 6 # Number of epochs to warm up the LR. -1 to turn off
    max_epochs: int = 40 # Max number of epochs
    amp_level: str = 'O0' # Apex optimization level
    num_classes: int = 2 # Number of classes in the dataset
    use_knowledge_distillation: bool = True # Use knowledge distillation from resnet-50
    distill_alpha: float = 0.5 # Distillation strength
    distill_temperature: int = 20 # Temperature hyper-parameter to make the outputs smoother for KD


In [5]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 21033), started 1 day, 21:12:43 ago. (Use '!kill 21033' to kill it.)

In [None]:
from pytorch_lightning import (
    Trainer,
    seed_everything,
)
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision.models import resnet18, resnet50
from timm import create_model
seed_everything(42)

config = Config()

checkpoint_callback = ModelCheckpoint(monitor="avg_val_acc", mode="max")
trainer = Trainer(
    gpus=1,
    amp_level=config.amp_level,
    amp_backend='apex',
    precision=16 if config.amp_level != 'O0' else 32,
    deterministic=True,
    benchmark=False,
    checkpoint_callback=checkpoint_callback,
    max_epochs=config.max_epochs
)

# create model
#model = resnet18(pretrained=True)
model = create_model('efficientnet_b0', pretrained=True, drop_rate=0.2)
fc = nn.Sequential(OrderedDict([#('fc1', nn.Linear(1280, 1000, bias=True)),
                                 ('fc1', nn.Linear(2048, 1000, bias=True)),
							     ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('dropout1', nn.Dropout(0.7)),
                                 ('fc2', nn.Linear(1000, 512)),
								 ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('swish1', Swish()),
								 ('dropout2', nn.Dropout(0.5)),
								 ('fc3', nn.Linear(512, 128)),
								 ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
							     ('swish2', Swish()),
								 ('fc4', nn.Linear(128, num_classes)),
								 ('output', nn.Softmax(dim=1))
							 ]))
# connect base model (EfficientNet_B0) with modified classifier layer
model.fc = fc
if config.use_knowledge_distillation:
    #teacher_model = resnet50(pretrained=False)
    teacher_model = create_model('efficientnet_b1', pretrained=True, drop_rate=0.2)
    fc = nn.Sequential(OrderedDict([#('fc1', nn.Linear(1280, 1000, bias=True)),
                                 ('fc1', nn.Linear(2048, 1000, bias=True)),
							     ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('dropout1', nn.Dropout(0.7)),
                                 ('fc2', nn.Linear(1000, 512)),
								 ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('swish1', Swish()),
								 ('dropout2', nn.Dropout(0.5)),
								 ('fc3', nn.Linear(512, 128)),
								 ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
							     ('swish2', Swish()),
								 ('fc4', nn.Linear(128, num_classes)),
								 ('output', nn.Softmax(dim=1))
							 ]))
# connect base model (EfficientNet_B0) with modified classifier layer
    teacher_model.fc = fc
    model = TB_CXR_KD(model, teacher_model, config)
else:
    model = TB_CXR(model, config)

trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                  | Params
----------------------------------------------------
0 | model     | EfficientNet          | 7 M   
1 | metric    | Accuracy              | 0     
2 | teacher   | EfficientNet          | 10 M  
3 | criterion | MixUpAugmentationLoss | 0     


Validation sanity check: 0it [00:00, ?it/s]
                                           
Validating:  17%|█▋        | 1/6 [00:01<00:05,  1.19s/it][A
Validating:  33%|███▎      | 2/6 [00:01<00:03,  1.03it/s][A
Epoch 0:   0%|          | 0/48 [00:00<?, ?it/s]          [A

  "blur_limit and sigma_limit minimum value can not be both equal to 0. "


Epoch 0:  88%|████████▊ | 42/48 [00:39<00:05,  1.07it/s, loss=743.496, v_num=5, train_loss=728, train_acc=0]     
Validating: 0it [00:00, ?it/s][A
Epoch 0:  90%|████████▉ | 43/48 [00:40<00:04,  1.07it/s, loss=743.496, v_num=5, train_loss=728, train_acc=0]
Epoch 0:  92%|█████████▏| 44/48 [00:40<00:03,  1.07it/s, loss=743.496, v_num=5, train_loss=728, train_acc=0]
Epoch 0:  94%|█████████▍| 45/48 [00:44<00:02,  1.02it/s, loss=743.496, v_num=5, train_loss=728, train_acc=0]
Epoch 0: 100%|██████████| 48/48 [00:44<00:00,  1.08it/s, loss=743.496, v_num=5, train_loss=728, train_acc=0]
Epoch 1:  88%|████████▊ | 42/48 [00:37<00:05,  1.11it/s, loss=729.183, v_num=5, train_loss=719, train_acc=0]     
Validating: 0it [00:00, ?it/s][A
Validating:  17%|█▋        | 1/6 [00:00<00:03,  1.38it/s][A
Epoch 1:  92%|█████████▏| 44/48 [00:39<00:03,  1.11it/s, loss=729.183, v_num=5, train_loss=719, train_acc=0]
Epoch 1:  96%|█████████▌| 46/48 [00:42<00:01,  1.08it/s, loss=729.183, v_num=5, train_loss=719, tr

In [None]:
trainer.test()