In [1]:
from dotenv import load_dotenv
from torchvision import transforms
import ast
from datahandling.transforms import to_numeric_label
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datahandling.dataloader import get_list_from_txt, extract_annotation_targets
import torch
import os
from dataset.chestxray import ChestXRayDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load settings from a local .env file into the process environment.
load_dotenv()

# DEBUG is parsed as a Python literal (e.g. "True" / "False"). Default to
# "False" so an unset variable does not make literal_eval fail with the
# unhelpful "malformed node or string: None".
debug = ast.literal_eval(os.getenv("DEBUG", "False"))

# Dataset location and file names; os.getenv returns None for any that are
# unset, in which case the os.path.join below raises TypeError.
db_path = os.getenv("DB_PATH")
img_dir_name = os.getenv("IMG_DIR")
class_file_name = os.getenv("CLASSIFICATION_FILE")
train_list = os.getenv("TRAIN_VAL_LIST")
test_list = os.getenv("TEST_LIST")

# Image directory lives under the dataset root.
img_dir = os.path.join(db_path, img_dir_name)
print(debug, db_path, img_dir_name)

True /cluster/home/larsira/tdt4900/databases/chest_xray14 images


In [3]:
# Resolve every path under the dataset root up front.
annotations_file = os.path.join(db_path, class_file_name)
train_annotation_file = os.path.join(db_path, "train.csv")
test_annotation_file = os.path.join(db_path, "test.csv")
# NOTE(review): train.csv / test.csv are only named here, not written in this
# cell - presumably they already exist on disk; confirm before a fresh run.

# Full annotation table, plus the entries listed in the train/test list files.
annotations = pd.read_csv(annotations_file)
train_images = get_list_from_txt(os.path.join(db_path, train_list))
test_images = get_list_from_txt(os.path.join(db_path, test_list))

# Select the annotation entries for each split via the "Image Index" column.
train_annotations = extract_annotation_targets(
    annotations, "Image Index", train_images
)
test_annotations = extract_annotation_targets(
    annotations, "Image Index", test_images
)

In [4]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """
    Build a linear schedule: a 1-D tensor of `timesteps` evenly spaced
    values running from `start` to `end` (both endpoints included).
    """
    schedule = torch.linspace(start=start, end=end, steps=timesteps)
    return schedule
    
def get_index_from_list(vals, t, x_shape):
    """ 
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.

    vals:    1-D tensor to look values up in.
    t:       integer index tensor with one entry per batch element.
    x_shape: shape of the tensor the result will be broadcast against;
             only its length (number of dimensions) is used.

    Returns a tensor of shape (len(t), 1, ..., 1) with len(x_shape)
    dimensions, placed on the same device as t.
    """
    batch_size = t.shape[0]
    # gather needs the index on the same device as vals; follow vals instead
    # of assuming it lives on the CPU.
    out = vals.gather(-1, t.to(vals.device))
    # One trailing singleton axis per non-batch dimension of x_shape so the
    # result broadcasts against a tensor of that shape.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """ 
    Noise an input x_0 to timestep t in a single step.

    Draws Gaussian noise shaped like x_0 and mixes it with x_0 using the
    per-timestep coefficients looked up in the module-level
    sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod tensors.

    Returns a tuple (noisy sample, noise), both moved to `device`.
    """
    noise = torch.randn_like(x_0)

    # Per-sample coefficients, shaped to broadcast against x_0.
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape)

    noise_on_device = noise.to(device)
    # Scaled input (mean term) plus scaled noise (variance term).
    noisy = (
        signal_scale.to(device) * x_0.to(device)
        + noise_scale.to(device) * noise_on_device
    )
    return noisy, noise_on_device


# Noise schedule: T linearly spaced beta values
T = 300
betas = linear_beta_schedule(timesteps=T)

# Quantities derived once from the schedule and reused by the closed-form
# forward process.
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 filled in at index 0.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [5]:
IMG_SIZE = 64
BATCH_SIZE = 128

# Preprocessing applied to every image: resize to IMG_SIZE x IMG_SIZE, random
# horizontal flip, then rescale.
data_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    # assumes the dataset yields pixel values in [0, 255] - TODO confirm
    transforms.Lambda(lambda t: t/255),
    # map [0, 1] to [-1, 1]
    transforms.Lambda(lambda t: (t * 2) -1)
])

# Kept as an alias of the pipeline above. Wrapping the existing Compose in a
# second Compose stores a non-iterable object as the transform list, so
# calling the result would raise TypeError.
data_transform = data_transforms

data_train = ChestXRayDataset(train_annotation_file, img_dir, transform=data_transforms, target_transform=to_numeric_label)
# No shuffling is requested, so batches follow the dataset's own order.
train_loader = DataLoader(data_train, BATCH_SIZE)

In [6]:
def show_image(image):
    """
    Display a tensor image with matplotlib.

    image: a channel-first tensor (C, H, W) with values nominally in
           [-1, 1], or a 4-D batch of such tensors, in which case only the
           first element of the batch is shown.
    """
    reverse_transform = transforms.Compose([
        # Noised samples can leave [-1, 1]; clamp first so the uint8 cast
        # below cannot wrap around.
        transforms.Lambda(lambda t: t.clamp(-1, 1)),
        # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: (t+1) / 2),
        # channel-first -> channel-last
        transforms.Lambda(lambda t: t.permute(1,2,0)),
        # [0, 1] -> [0, 255]
        transforms.Lambda(lambda t: t*255),
        # detach/cpu so tensors that track gradients or live on an
        # accelerator can still be converted to numpy
        transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # For a batch, show only its first image.
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transform(image))

In [7]:
# First element of the first batch from the loader
# (presumably the image tensor of an (image, label) pair - confirm against
# ChestXRayDataset).
image = next(iter(train_loader))[0]

plt.figure(figsize=(15,15))
plt.axis("off")
num_images=10
stepsize=int(T/num_images)

# Show the same clean input noised to increasing timesteps, one panel each.
for idx in range(0, T, stepsize):
    t = torch.Tensor([idx]).type(torch.int64)
    plt.subplot(1, num_images+1, int((idx/stepsize) + 1))
    # Keep `image` as the clean x_0: forward_diffusion_sample jumps straight
    # from x_0 to timestep t, so feeding its output back in would compound
    # the noise from previous panels.
    noisy_image, noise = forward_diffusion_sample(image, t)
    show_image(noisy_image)

NameError: name 'path' is not defined

In [25]:
# Debug peek: fetch only the first batch from the loader, unpack it into its
# two elements, print the second one, then stop iterating.
for idx, value in enumerate(train_loader):
    # Each batch is expected to unpack into exactly two items.
    test_img, test_lab = value
    print(test_lab)
    # Only the first batch is needed.
    break

tensor([[[202, 199, 195,  ...,   5,   2,   0],
         [199, 196, 195,  ...,   5,   2,   0],
         [196, 194, 193,  ...,   5,   2,   0],
         ...,
         [255, 255, 255,  ...,   0,   0,   0],
         [255, 255, 254,  ...,   0,   0,   0],
         [255, 255, 255,  ...,   0,   0,   0]],

        [[202, 199, 195,  ...,   5,   2,   0],
         [199, 196, 195,  ...,   5,   2,   0],
         [196, 194, 193,  ...,   5,   2,   0],
         ...,
         [255, 255, 255,  ...,   0,   0,   0],
         [255, 255, 254,  ...,   0,   0,   0],
         [255, 255, 255,  ...,   0,   0,   0]],

        [[202, 199, 195,  ...,   5,   2,   0],
         [199, 196, 195,  ...,   5,   2,   0],
         [196, 194, 193,  ...,   5,   2,   0],
         ...,
         [255, 255, 255,  ...,   0,   0,   0],
         [255, 255, 254,  ...,   0,   0,   0],
         [255, 255, 255,  ...,   0,   0,   0]]], dtype=torch.uint8)
tensor([[[208, 205, 206,  ..., 204, 215, 139],
         [209, 203, 205,  ..., 202, 210,



tensor([[[ 18,  29,  27,  ...,  16,  22,  17],
         [ 36,  61,  55,  ...,  32,  40,  27],
         [ 34,  58,  52,  ...,  29,  33,  21],
         ...,
         [130, 226, 216,  ..., 139, 150,  89],
         [ 64, 111, 105,  ...,  69,  75,  44],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[ 18,  29,  27,  ...,  16,  22,  17],
         [ 36,  61,  55,  ...,  32,  40,  27],
         [ 34,  58,  52,  ...,  29,  33,  21],
         ...,
         [130, 226, 216,  ..., 139, 150,  89],
         [ 64, 111, 105,  ...,  69,  75,  44],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[ 18,  29,  27,  ...,  16,  22,  17],
         [ 36,  61,  55,  ...,  32,  40,  27],
         [ 34,  58,  52,  ...,  29,  33,  21],
         ...,
         [130, 226, 216,  ..., 139, 150,  89],
         [ 64, 111, 105,  ...,  69,  75,  44],
         [  0,   0,   0,  ...,   0,   0,   0]]], dtype=torch.uint8)
tensor([[[  1,   0,  13,  ...,  10,  11,   6],
         [  1,   0,  26,  ...,  20,  21,