In [7]:
from torch import nn, optim, float32, LongTensor, uint8
import torch
import numpy as np
from torch.nn.functional import mse_loss
#import torch.nn.functional as F
from util.plot_tools import show_and_save
from diffusion.sampling import sample_single_image
from diffusers.optimization import get_constant_schedule_with_warmup
from dataset.chestxray import ChestXRayDataset
from datahandling.dataloader import get_list_from_txt, extract_annotation_targets, extract_unique_labels, extract_n_single_label_images, extract_n_images_from_labels
from datahandling.transforms import to_numeric_label, to_class_int
from torch.utils.data import DataLoader
from torchvision import datasets, models
from models.diffusers import cxr_unet
from diffusion.sampling import sample_single_image
import torchvision.transforms as transforms
import os
import ast
import torchvision
from dotenv import load_dotenv
from diffusers import DDPMScheduler
import timeit
from tqdm import tqdm

In [2]:
# Load settings from .env into os.environ so the getenv calls below can see them.
load_dotenv()

# DEBUG is parsed as a Python literal (e.g. "True"/"False"). Default to "False" so an
# unset variable does not make ast.literal_eval(None) raise ValueError.
debug = ast.literal_eval(os.getenv("DEBUG", "False"))
db_path = os.getenv("DB_PATH")
img_dir_name = os.getenv("IMG_DIR")
class_file_name = os.getenv("CLASSIFICATION_FILE")
train_list = os.getenv("TRAIN_VAL_LIST")
test_list = os.getenv("TEST_LIST")

# DB_PATH and IMG_DIR are joined into file paths by later cells; fail here with a clear
# message instead of an opaque TypeError from os.path.join(None, ...).
missing_vars = [
    name for name, value in (("DB_PATH", db_path), ("IMG_DIR", img_dir_name)) if value is None
]
if missing_vars:
    raise RuntimeError(f"Missing required environment variable(s): {', '.join(missing_vars)}")

# Seed torch's generators (CPU and all CUDA devices) so noise, timesteps and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Release cached blocks left over from a previous run in this kernel.
    torch.cuda.empty_cache()

print(device, test_list)

cuda test_list.txt


In [3]:
# Name of the labels/targets CSV, resolved relative to DB_PATH by the next cell.
# Overridable via the TARGET_FILE env var (like the other paths read from .env);
# defaults to the small subset used for quick experiments.
t_f_n = os.getenv("TARGET_FILE", "small_file.csv")

In [4]:
target_file = os.path.join(db_path, t_f_n)
img_dir = os.path.join(db_path, img_dir_name)
img_size = 128
batch_size = 8

# Resize to a fixed square, convert to a tensor, then normalize with the ImageNet
# per-channel statistics (3 channels).
# NOTE(review): DDPM-style training usually scales inputs to [-1, 1]
# (e.g. Normalize([0.5] * 3, [0.5] * 3)); confirm sample_single_image inverts
# whichever normalization is used here.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = ChestXRayDataset(target_file, img_dir, transform=transform, read_lib="pil", target_transform=to_class_int)

# shuffle=True: without it every epoch iterates the same batches in the same order,
# which biases SGD/Adam training.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

len(dataset)

1000


In [8]:
learning_rate = 0.001
epochs = 100
inference_steps = 1000
sampling_interval = 10
warmup_steps = 10

# The first positional argument of DDPMScheduler is num_train_timesteps, so this value
# also bounds the training timesteps drawn below (via noise_scheduler.config).
noise_scheduler = DDPMScheduler(inference_steps)
# NOTE(review): second argument is presumably the number of class embeddings; the +1
# looks like an extra (e.g. unconditional) label slot — confirm against cxr_unet.
model = cxr_unet(img_size, len(ChestXRayDataset.target_labels) + 1).to(device)
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Linear warmup from lr=0 over `warmup_steps` optimizer steps, then constant.
# The scheduler sets lr to 0 on construction, so lr_scheduler.step() MUST run after
# every optimizer.step(); otherwise the model is never updated.
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=warmup_steps)

for epoch in tqdm(range(epochs), desc="epochs"):
    running_loss = 0.0
    model.train()

    for images, labels in tqdm(loader, desc=f"epoch {epoch + 1}", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        # Forward diffusion: corrupt each image at a random timestep; the model is
        # trained to predict the noise that was added.
        noise = torch.randn_like(images)
        time_steps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (len(images),), device=device
        )
        noised_imgs = noise_scheduler.add_noise(images, noise, time_steps)

        optimizer.zero_grad()
        pred = model(noised_imgs, time_steps, class_labels=labels.flatten())[0]
        loss = loss_fn(pred, noise)

        # Update weights, then advance the warmup schedule.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        running_loss += loss.item()

    mean_loss = running_loss / max(len(loader), 1)
    tqdm.write(
        f"epoch {epoch + 1}/{epochs}  mean loss {mean_loss:.5f}  lr {lr_scheduler.get_last_lr()[0]:.2e}"
    )

    # `epoch + 1` so a sample is drawn after every `sampling_interval` completed
    # epochs (10, 20, ...) rather than right after the first epoch (epoch == 0).
    if (epoch + 1) % sampling_interval == 0:
        # Inference only: disable train-mode layers and autograd bookkeeping.
        # model.train() at the top of the next epoch restores training mode.
        model.eval()
        with torch.no_grad():
            sample_single_image(model, img_size, device, inference_steps, "No Finding", ChestXRayDataset.target_labels)
    

  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/125 [00:00<?, ?it/s][A
  1%|          | 1/125 [00:00<01:12,  1.72it/s][A
  2%|▏         | 2/125 [00:01<01:10,  1.75it/s][A
  2%|▏         | 3/125 [00:01<01:08,  1.77it/s][A
  3%|▎         | 4/125 [00:02<01:08,  1.78it/s][A
  4%|▍         | 5/125 [00:02<01:07,  1.79it/s][A
  5%|▍         | 6/125 [00:03<01:06,  1.78it/s][A
  6%|▌         | 7/125 [00:03<01:05,  1.79it/s][A
  6%|▋         | 8/125 [00:04<01:05,  1.79it/s][A
  7%|▋         | 9/125 [00:05<01:04,  1.79it/s][A
  8%|▊         | 10/125 [00:05<01:03,  1.80it/s][A
  9%|▉         | 11/125 [00:06<01:03,  1.80it/s][A
 10%|▉         | 12/125 [00:06<01:02,  1.80it/s][A
 10%|█         | 13/125 [00:07<01:02,  1.79it/s][A
 11%|█         | 14/125 [00:07<01:02,  1.78it/s][A
 12%|█▏        | 15/125 [00:08<01:01,  1.79it/s][A
 13%|█▎        | 16/125 [00:08<01:00,  1.79it/s][A
 14%|█▎        | 17/125 [00:09<01:00,  1.78it/s][A
 14%|█▍        | 18/125 [00:10<01:00,  1.7

KeyboardInterrupt: 