In [1]:
# The implementation here is based on the butterfly example from the hugging face documentation: https://huggingface.co/docs/diffusers/v0.19.2/tutorials/basic_training
# The implementation follows the  

from torch import nn, optim, float32, LongTensor, uint8
import torch
import numpy as np
import torch.nn.functional as F
from util.plot_tools import show_and_save
from diffusion.sampling import sample_single_image
from dataset.chestxray import ChestXRayDataset
from datahandling.dataloader import get_list_from_txt, extract_annotation_targets, extract_unique_labels, extract_n_single_label_images, extract_n_images_from_labels
from datahandling.transforms import to_numeric_label, to_class_int
from torch.utils.data import DataLoader
from torchvision import datasets, models
from models.diffusers import cxr_unet
import torchvision.transforms as transforms
import os
import ast
import torchvision
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from diffusers import DDPMScheduler
import timeit

  from .autonotebook import tqdm as notebook_tqdm


# Load Environment Variables

In [2]:
load_dotenv()
debug = ast.literal_eval(os.getenv("DEBUG"))
db_path = os.getenv("DB_PATH")
img_dir_name = os.getenv("IMG_DIR")
class_file_name = os.getenv("CLASSIFICATION_FILE")
train_list = os.getenv("TRAIN_VAL_LIST")
test_list = os.getenv("TEST_LIST")

device="cpu"
if torch.cuda.is_available():
    device="cuda"

torch.cuda.empty_cache()
print(device, test_list)

cuda test_list.txt


# Set Constants

In [3]:
NUM_TRAIN_TIMESTEPS = 1000
IMG_SIZE = 128
NUM_GENERATE_IMAGES = 9
BATCH_SIZE=2
WARMUP_STEPS=100
NUM_EPOCHS=100
LEARNING_RATE= 1e-4
MIXED_PRECISION="fp16"
GRADIENT_ACCUMULATION_STEPS=1
NUM_TIMESTEPS = 100

# Load data

In [4]:
img_dir = os.path.join(db_path, img_dir_name)
annotations_file = os.path.join(db_path, class_file_name)
target_file = os.path.join(db_path, "small_file.csv")
annots = extract_n_images_from_labels(annotations_file, 500, [ChestXRayDataset.target_labels[0], "Mass"], target_file, True)

preprocess_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]),
])

train_dataset = ChestXRayDataset(target_file, img_dir, transform=preprocess_transforms, read_lib="pil", target_transform=to_class_int)
train_loader = DataLoader(train_dataset, BATCH_SIZE)

# Define Model

In [5]:
model = cxr_unet(IMG_SIZE, len(ChestXRayDataset.target_labels))

# Train

In [8]:
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=(len(train_loader) * NUM_EPOCHS)
)

accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)
noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TRAIN_TIMESTEPS)

start = timeit.default_timer()
for epoch in tqdm(range(NUM_EPOCHS), position=0, leave=True):
    model.train()
    train_running_loss=0
    for idx, batch in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        clean_images=batch[0].to(device)
        labels=batch[1].flatten().to(device)
        clean_images.to(device)
        noise = torch.randn(clean_images.shape).to(device)
        last_batch_size=len(clean_images)

        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (last_batch_size,)).to(device)
        print(labels)
        print(timesteps)
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        with accelerator.accumulate(model):
            noise_pred = model(noisy_images, timesteps, return_dict=False, class_labels=labels)[0]
            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)

            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
        train_running_loss += loss.item()
    train_loss = train_running_loss / (idx+1)

    train_learning_rate = lr_scheduler.get_last_lr()[0]
    print("-"*30)
    print(f"train loss epoch: {epoch+1}: {train_loss:.4f}")
    print(f"Train Learning Rate EPOCH: {epoch+1}: {train_learning_rate}")
    if epoch%10 == 0:
        img_1 = sample_single_image(model, IMG_SIZE, device, NUM_TRAIN_TIMESTEPS, ChestXRayDataset.target_labels[0], ChestXRayDataset.target_labels, epoch)
        show_and_save(img_1, "Class: {}, Epoch: {}".format(ChestXRayDataset.target_labels[0], epoch), "result/{}_{}".format(ChestXRayDataset.target_labels[0], epoch))
        img_2 = sample_single_image(model, IMG_SIZE, device, NUM_TRAIN_TIMESTEPS, "Mass", ChestXRayDataset.target_labels, epoch)
        show_and_save(img_2, "Class: {}, Epoch: {}".format("Mass", epoch), "result/{}_{}".format("Mass", epoch))
        # sample_image_generation(model, noise_scheduler, NUM_GENERATE_IMAGES, RANDOM_SEED, NUM_TIMESTEPS)
    print("-"*30)

  0%|          | 0/500 [00:00<?, ?it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([880, 427], device='cuda:0')


  0%|          | 1/500 [00:00<03:32,  2.34it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([494, 522], device='cuda:0')


  0%|          | 2/500 [00:00<03:15,  2.55it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([416, 628], device='cuda:0')


  1%|          | 3/500 [00:01<03:06,  2.66it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([614,  95], device='cuda:0')


  1%|          | 4/500 [00:01<03:02,  2.72it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([ 16, 804], device='cuda:0')


  1%|          | 5/500 [00:01<02:59,  2.76it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([664,  99], device='cuda:0')


  1%|          | 6/500 [00:02<02:58,  2.77it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([965, 336], device='cuda:0')


  1%|▏         | 7/500 [00:02<03:17,  2.49it/s]
  0%|          | 0/100 [00:02<?, ?it/s]

tensor([0, 0], device='cuda:0', dtype=torch.int32)
tensor([842, 711], device='cuda:0')




KeyboardInterrupt



In [None]:
sample_single_image(model, IMG_SIZE, device, NUM_TRAIN_TIMESTEPS, ChestXRayDataset.target_labels[0], ChestXRayDataset.target_labels, 100)

In [None]:
sample_single_image(model, IMG_SIZE, device, NUM_TRAIN_TIMESTEPS, "Mass", ChestXRayDataset.target_labels, 1001)