In [1]:
import os
import cv2
import glob
import timm
import random
import numpy as np
import pandas as pd
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics.functional import mean_squared_error, mean_absolute_error, accuracy, f1_score

batch_size = 128

df = pd.read_excel("data/US_fibrosis_stage_dataset.xlsx", engine="openpyxl")
# only HBV patients
# df = df[df.Etiology == 1].reset_index(drop=True)
df = df.loc[(df.AST < 100) | (df.ALT < 100)].reset_index(drop=True)
df = df.loc[:, ["ID", "kPa_fib"]].dropna().reset_index(drop=True)
df.ID = df.ID.map(lambda x: str(x).zfill(8))

flist = os.listdir("data/roi_sampled/")
id_list = list(map(lambda x: x.split("_")[0].zfill(8), flist))

image_df = pd.DataFrame(glob.glob(os.path.join("data", "roi_sampled", "*.jpg")), columns=["image_path"])
image_df.loc[:, "ID"] = image_df.image_path.map(lambda x: x.split("/")[-1].split("-")[0])

df = pd.merge(df, image_df, on="ID", how="inner")
df.head()

Unnamed: 0,ID,kPa_fib,image_path
0,266195,3.8,data/roi_sampled/00266195-0.jpg
1,266195,3.8,data/roi_sampled/00266195-1.jpg
2,266195,3.8,data/roi_sampled/00266195-10.jpg
3,266195,3.8,data/roi_sampled/00266195-11.jpg
4,266195,3.8,data/roi_sampled/00266195-12.jpg


In [2]:
ids = df.ID.drop_duplicates().reset_index(drop=True)

train_id, test_id = train_test_split(ids, test_size=0.15, random_state=42)
train_id, valid_id = train_test_split(train_id, test_size=0.15, random_state=42)

train_df = df[df.ID.isin(train_id)].reset_index(drop=True)
valid_df = df[df.ID.isin(valid_id)].reset_index(drop=True)
test_df = df[df.ID.isin(test_id)].reset_index(drop=True)

print("Train: ", len(train_df.ID.drop_duplicates()))
print("Valid: ", len(valid_df.ID.drop_duplicates()))
print("Test: ", len(test_df.ID.drop_duplicates()))

Train:  852
Valid:  151
Test:  177


In [3]:
def define_augmentation(w, h):
    train_transforms = A.Compose([ 
        A.Resize(width=w, height=h, p=1.0),
        A.OneOf([
            A.Downscale(),
        ], p=0.5),        
        
        A.HorizontalFlip(p=0.5),
        
        A.Affine(p=0.8),
        
        A.OneOf([
            A.RandomBrightnessContrast(),
            A.RandomBrightness(),
            A.RandomContrast()
        ], p=0.5),
        
        A.Normalize(p=1.0),
        ToTensorV2()
    ])

    valid_transforms = A.Compose([ 
        A.Resize(width=w, height=h, p=1.0),
        A.Normalize(p=1.0),
        ToTensorV2()
    ])

    return train_transforms, valid_transforms


class SonographyDataset(Dataset):
    def __init__(self, df, transform, train_mode=False):
        self.df = df
        self.transform = transform
        self.train_mode = train_mode
        
        
    def __len__(self):
        return len(self.df)

    
    def __getitem__(self, idx):
        image = cv2.imread(self.df.loc[idx, "image_path"])
        image = self.transform(image=image)
        
        y = self.df.loc[idx, "kPa_fib"]
        if self.train_mode:
            y += np.random.rand(1)[0] - 0.5

        return image['image'], torch.tensor(y).log().float()
    
    
train_transform, valid_transform = define_augmentation(w=224, h=224)

train_dataset = SonographyDataset(train_df, train_transform, train_mode=True)
valid_dataset = SonographyDataset(valid_df, valid_transform)
test_dataset = SonographyDataset(test_df, valid_transform)


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, 
                              num_workers=12, prefetch_factor=10,
                              pin_memory=True)

valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, 
                              num_workers=12, prefetch_factor=10,
                              pin_memory=True)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, 
                              num_workers=12, prefetch_factor=10,
                              pin_memory=True)



In [4]:
class ResizeConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        x = self.conv(x)
        
        return x

    
class BasicBlock(nn.Module):
    def __init__(self, in_planes, stride=1):
        super().__init__()
        planes = in_planes*stride

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        
        return out
    
    
class Encoder(nn.Module):
    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlock, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, num_Blocks[3], stride=2)
        self.linear = nn.Linear(512, 4 * z_dim)
        
        self.z_mu1 = nn.Linear(z_dim * 4, z_dim * 2)
        self.z_mu2 = nn.Linear(z_dim * 2, z_dim)
        
        self.z_log_var1 = nn.Linear(z_dim * 4, z_dim * 2)
        self.z_log_var2 = nn.Linear(z_dim * 2, z_dim)
        
        self.r_mu1 = nn.Linear(z_dim * 4, z_dim * 2)
        self.r_mu2 = nn.Linear(z_dim * 2, 1)
        
        self.r_log_var1 = nn.Linear(z_dim * 4, z_dim * 2)
        self.r_log_var2 = nn.Linear(z_dim * 2, 1)
        
        
    def _make_layer(self, BasicBlock, planes, num_Blocks, stride):
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        
        for stride in strides:
            layers += [BasicBlock(self.in_planes, stride)]
            self.in_planes = planes
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        
        z_mu = F.tanh(self.z_mu1(x))
        z_mu = self.z_mu2(z_mu)
        
        z_log_var = F.tanh(self.z_log_var1(x))
        z_log_var = self.z_log_var2(z_log_var)
        
        r_mu = F.tanh(self.r_mu1(x))
        r_mu = self.r_mu2(r_mu)
        
        r_log_var = F.tanh(self.r_log_var1(x))
        r_log_var = self.r_log_var2(r_log_var)
        
        return z_mu, z_log_var, r_mu, r_log_var
    
    
class Decoder(nn.Module):
    def __init__(self, z_dim=10, hidden_dim=4096):
        super().__init__()
        
        self.block1 = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )
        
        self.block2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim)
        )
        
        self.block3 = nn.Sequential(
            nn.Linear(hidden_dim, 256 * 6 * 6),
            nn.BatchNorm1d(256 * 6 * 6)
        )
        
        self.upsample = nn.Upsample(scale_factor=2)
        
        self.dconv1 = nn.ConvTranspose2d(256, 256, 3, padding=0)
        self.dconv2 = nn.ConvTranspose2d(256, 384, 3, padding=1)
        self.dconv3 = nn.ConvTranspose2d(384, 192, 3, padding=1)
        self.dconv4 = nn.ConvTranspose2d(192, 64, 5, padding=2)
        self.dconv5 = nn.ConvTranspose2d(64, 3, 12, stride=4, padding=4)

    def forward(self, x):
        for block in [self.block1, self.block2, self.block3]:
            x = F.relu(block(x))

        x = x.view(-1, 256, 6, 6)
        x = self.upsample(x)
        
        for i, layer in enumerate([self.dconv1, self.dconv2, self.dconv3, self.dconv4, self.dconv5]):
            x = F.relu(layer(x))
            if i in [2, 3]:
                x = self.upsample(x)
        
        x = F.sigmoid(x)

        return x

In [5]:
class VAE(nn.Module):
    def __init__(self, encoder, decoder, z_dim=10):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
        self.pz_mu = nn.Linear(1, z_dim)
        self.pz_log_var = nn.Linear(1, 1)
        
        
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        
        return eps * std + mu
        
        
    def forward(self, x):
        z_mu, z_log_var, r_mu, r_log_var = self.encoder(x)
        
        z = self.reparameterize(z_mu, z_log_var)
        r = self.reparameterize(r_mu, r_log_var)
        
        pz_mu = self.pz_mu(r)
        pz_log_var = self.pz_log_var(r)
        
        return self.decoder(z), z_mu, z_log_var, r_mu, r_log_var, r, pz_mu, pz_log_var

    
def vae_loss(orig_image, orig_reg, recon_image, z_mu, z_log_var, r_mu, r_log_var, r, pz_mu, pz_log_var):
    recon_loss = F.mse_loss(recon_image, orig_image)
    kld_loss = torch.mean(-0.5 * torch.sum(1 + z_log_var - pz_log_var - ((z_mu - pz_mu) ** 2 / pz_log_var.exp()) - (z_log_var.exp() / pz_log_var.exp()), dim=1), dim=0)
    reg_loss = torch.mean((0.5 * (r_mu - orig_reg) ** 2 / r_log_var.exp()) + 0.5 * r_log_var)
    
    total_loss = recon_loss + kld_loss + reg_loss 
    
    return total_loss, recon_loss, kld_loss, reg_loss


z_dim = 1024
encoder = Encoder(z_dim=z_dim)
decoder = Decoder(z_dim=z_dim)
vae = VAE(encoder, decoder, z_dim=z_dim)

In [6]:
class KpaPredictor(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
        
    def forward(self, x):
        return self.model(x)
    
    
    def step(self, batch):
        # x: image, y: kpa
        x, y = batch
        recon_image, z_mu, z_log_var, r_mu, r_log_var, r, pz_mu, pz_log_var = self(x)
        total_loss, recon_loss, kld_loss, reg_loss = vae_loss(x, y, recon_image, z_mu, z_log_var, r_mu, r_log_var, r, pz_mu, pz_log_var)
        
        preds = torch.squeeze(r_mu, -1)
        acc = mean_squared_error(preds, y)
        
        return preds, total_loss, acc
    
    
    def training_step(self, batch, batch_idx):
        _, loss, acc = self.step(batch)
        
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', acc, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def validation_step(self, batch, batch_idx):
        _, loss, acc = self.step(batch)
        
        self.log('valid_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('valid_accuracy', acc, on_step=False, on_epoch=True, prog_bar=True)
        
    
    def test_step(self, batch, batch_idx):
        _, loss, acc = self.step(batch)
        
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_accuracy', acc, on_step=False, on_epoch=True, prog_bar=True)
    
    
    def predict_step(self, batch, batch_idx):
        preds, _, _ = self.step(batch)
        
        return preds

    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
    
    
    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        scheduler.step(epoch=self.current_epoch)

    
    
callbacks = [
    ModelCheckpoint(monitor='valid_loss', save_top_k=3, dirpath='weights/regression_pz_loss', filename='kpa_predictor-{epoch:03d}-{valid_loss:.4f}-{valid_accuracy:.4f}'),
]


kpa_predictor = KpaPredictor(vae)

trainer = pl.Trainer(max_epochs=100, gpus=[0], 
                     enable_progress_bar=True, 
                     callbacks=callbacks, precision=16)

  rank_zero_deprecation(
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(kpa_predictor, train_dataloader, valid_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type | Params
-------------------------------
0 | model | VAE  | 112 M 
-------------------------------
112 M     Trainable params
0         Non-trainable params
112 M     Total params
224.563   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [9]:
ckpt_fname = ""
kpa_predictor = kpa_predictor.load_from_checkpoint("weights/regression_pz_loss/" + ckpt_fname, model=vae)

trainer.test(kpa_predictor, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

  reg_loss = F.l1_loss(pred_reg, orig_reg)
  reg_loss = F.l1_loss(pred_reg, orig_reg)


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.5028859972953796
        test_loss            4.461840629577637
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 4.461840629577637, 'test_accuracy': 0.5028859972953796}]